From dcd063636a33b0b677f0f1e301f7496735b9bd8e Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 26 Sep 2024 19:04:51 +0800 Subject: [PATCH] opt++ --- src/layer/arm/gemm_int8.h | 134 +++++++++++++++++-------------- src/layer/arm/gemm_int8_bf16s.h | 137 ++++++++++++++++---------------- 2 files changed, 143 insertions(+), 128 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 10045d2f722..fd8f56985b7 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -10137,7 +10137,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 8x8 to 8x4 and 8x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + 4); float32x4_t _c2 = vld1q_f32(pC + c_hstep); @@ -10146,53 +10145,61 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); - float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); - float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); - float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); - float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4); - float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); - float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c8); - _f5 = vaddq_f32(_f5, _ca); - _f6 = vaddq_f32(_f6, _cc); - _f7 = vaddq_f32(_f7, _ce); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c3); - _fa = vaddq_f32(_fa, _c5); - _fb = vaddq_f32(_fb, _c7); - _fc = vaddq_f32(_fc, _c9); - _fd = vaddq_f32(_fd, _cb); - _fe = vaddq_f32(_fe, _cd); - _ff = vaddq_f32(_ff, _cf); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c8, _beta); - _f5 = vmlaq_f32(_f5, _ca, _beta); - _f6 = vmlaq_f32(_f6, _cc, _beta); - _f7 = vmlaq_f32(_f7, _ce, _beta); - _f8 = vmlaq_f32(_f8, _c1, _beta); - _f9 = vmlaq_f32(_f9, _c3, _beta); - _fa = vmlaq_f32(_fa, _c5, _beta); - _fb = vmlaq_f32(_fb, _c7, _beta); - _fc = vmlaq_f32(_fc, _c9, _beta); - _fd = vmlaq_f32(_fd, _cb, _beta); - _fe = vmlaq_f32(_fe, _cd, _beta); - _ff = vmlaq_f32(_ff, _cf, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = 
vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } pC += 8; } @@ -10451,38 +10458,45 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 4x8 to 4x4 and 4x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + c_hstep); float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); - transpose4x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + transpose4x4_ps(_c0, _c1, _c2, _c3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c3); - _f6 = vaddq_f32(_f6, _c5); - _f7 = vaddq_f32(_f7, _c7); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c1, _beta); - _f5 = vmlaq_f32(_f5, _c3, _beta); - _f6 = vmlaq_f32(_f6, _c5, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } pC += 4; } diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 3f58bd64438..989cdf6866a 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -5870,19 +5870,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } else // if (c_elempack == 4) { - // TODO optimize uint16x8_t _cc0 = vld1q_u16(pC); uint16x8_t _cc1 = vld1q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_cc0)); - _c1 = bfloat2float(vget_high_u16(_cc0)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_cc1)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_cc1)); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - _c0 = vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1])); - _c1 = vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1])); - _c2 = vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1])); - _c3 = vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1])); + uint16x8x2_t _cc = 
vzipq_u16(vcombine_u16(vget_low_u16(_cc0), vget_low_u16(_cc1)), vcombine_u16(vget_high_u16(_cc0), vget_high_u16(_cc1))); + _c0 = bfloat2float(vget_low_u16(_cc.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc.val[1])); pC += 8; } if (beta == 1.f) @@ -8703,16 +8697,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 8x8 to 8x4 and 8x4 uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); - transpose8x8_u16(_c01, _c23, _c45, _c67, _c89, _cab, _ccd, _cef); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -8721,52 +8710,64 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c8); - _f5 = vaddq_f32(_f5, _ca); - _f6 = vaddq_f32(_f6, _cc); - _f7 = vaddq_f32(_f7, _ce); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c3); - _fa = vaddq_f32(_fa, _c5); - _fb = vaddq_f32(_fb, _c7); - _fc = vaddq_f32(_fc, _c9); - _fd = vaddq_f32(_fd, _cb); - _fe = vaddq_f32(_fe, _cd); - _ff = vaddq_f32(_ff, _cf); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c8, _beta); - _f5 = vmlaq_f32(_f5, _ca, _beta); - _f6 = vmlaq_f32(_f6, _cc, _beta); - _f7 = vmlaq_f32(_f7, _ce, _beta); - _f8 = vmlaq_f32(_f8, _c1, _beta); - _f9 = vmlaq_f32(_f9, _c3, _beta); - _fa = vmlaq_f32(_fa, _c5, _beta); - _fb = vmlaq_f32(_fb, _c7, _beta); - _fc = vmlaq_f32(_fc, _c9, _beta); - _fd = vmlaq_f32(_fd, _cb, _beta); - _fe = vmlaq_f32(_fe, _cd, _beta); - _ff = vmlaq_f32(_ff, _cf, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + 
transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } pC += 8; } @@ -9026,20 +9027,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 4x8 to 4x4 and 4x4 uint16x4_t _cc0 = vld1_u16(pC); uint16x4_t _cc1 = vld1_u16(pC + c_hstep); uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4); - uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5); - uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); - uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); - transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc2); - float32x4_t _c2 = bfloat2float(_cc4); - float32x4_t _c3 = bfloat2float(_cc6); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -9055,10 +9051,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _c0 = bfloat2float(_cc1); - _c1 = bfloat2float(_cc3); - _c2 = bfloat2float(_cc5); - _c3 = bfloat2float(_cc7); + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0);
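
Note on the pattern used throughout this patch: each hunk replaces a full 8x8 (or 4x8 / 8x4) transpose of the loaded C tile with two 4-wide transposes, loading and consuming rows 0-3 before rows 4-7, so only half as many C vectors are live while beta is applied. The sketch below is only illustrative of that two-pass structure; transpose4x4 here is a generic vtrnq-based NEON transpose rather than ncnn's transpose4x4_ps / transpose8x4_ps helpers, add_transposed_c_4x8 is a hypothetical standalone function that is not part of this patch, and it folds the beta == 1.f fast path into a single vmlaq for brevity.

    // Illustrative sketch, not ncnn code: standard NEON 4x4 float transpose
    // plus a two-pass accumulate over an 8-row column slice of C.
    #include <arm_neon.h>
    #include <stddef.h>

    static inline void transpose4x4(float32x4_t& r0, float32x4_t& r1,
                                    float32x4_t& r2, float32x4_t& r3)
    {
        float32x4x2_t t01 = vtrnq_f32(r0, r1); // pairwise interleave rows 0/1
        float32x4x2_t t23 = vtrnq_f32(r2, r3); // pairwise interleave rows 2/3
        r0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
        r1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
        r2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
        r3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
    }

    // f[0..7] accumulate a transposed 4x8 block of C (hstep = row stride of C).
    // Hypothetical helper, only to show the two-pass decomposition: rows 0-3
    // are loaded, transposed and folded into f[0..3] before rows 4-7 are
    // touched, so at most four C vectors are live at a time.
    static void add_transposed_c_4x8(float32x4_t f[8], const float* pC,
                                     size_t hstep, float beta)
    {
        float32x4_t b = vdupq_n_f32(beta);
        for (int half = 0; half < 2; half++) // first rows 0-3, then rows 4-7
        {
            float32x4_t c0 = vld1q_f32(pC + hstep * (half * 4 + 0));
            float32x4_t c1 = vld1q_f32(pC + hstep * (half * 4 + 1));
            float32x4_t c2 = vld1q_f32(pC + hstep * (half * 4 + 2));
            float32x4_t c3 = vld1q_f32(pC + hstep * (half * 4 + 3));
            transpose4x4(c0, c1, c2, c3);
            f[half * 4 + 0] = vmlaq_f32(f[half * 4 + 0], c0, b);
            f[half * 4 + 1] = vmlaq_f32(f[half * 4 + 1], c1, b);
            f[half * 4 + 2] = vmlaq_f32(f[half * 4 + 2], c2, b);
            f[half * 4 + 3] = vmlaq_f32(f[half * 4 + 3], c3, b);
        }
    }

The bf16 hunks follow the same idea with transpose8x4_u16 / transpose4x4_u16 on the raw uint16 data before bfloat2float, which likewise keeps the working set at eight (or four) vectors per pass instead of sixteen.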