Skip to content

Commit

Permalink
opt++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 26, 2024
1 parent 6302017 commit dcd0636
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 128 deletions.
134 changes: 74 additions & 60 deletions src/layer/arm/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -10137,7 +10137,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
{
if (c_elempack == 1)
{
// TODO decompose 8x8 to 8x4 and 8x4
_c0 = vld1q_f32(pC);
_c1 = vld1q_f32(pC + 4);
float32x4_t _c2 = vld1q_f32(pC + c_hstep);
Expand All @@ -10146,53 +10145,61 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4);
float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3);
float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4);
float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4);
float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4);
float32x4_t _ca = vld1q_f32(pC + c_hstep * 5);
float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4);
float32x4_t _cc = vld1q_f32(pC + c_hstep * 6);
float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4);
float32x4_t _ce = vld1q_f32(pC + c_hstep * 7);
float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4);
transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf);
transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7);
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c2);
_f2 = vaddq_f32(_f2, _c4);
_f3 = vaddq_f32(_f3, _c6);
_f4 = vaddq_f32(_f4, _c8);
_f5 = vaddq_f32(_f5, _ca);
_f6 = vaddq_f32(_f6, _cc);
_f7 = vaddq_f32(_f7, _ce);
_f8 = vaddq_f32(_f8, _c1);
_f9 = vaddq_f32(_f9, _c3);
_fa = vaddq_f32(_fa, _c5);
_fb = vaddq_f32(_fb, _c7);
_fc = vaddq_f32(_fc, _c9);
_fd = vaddq_f32(_fd, _cb);
_fe = vaddq_f32(_fe, _cd);
_ff = vaddq_f32(_ff, _cf);
_f1 = vaddq_f32(_f1, _c1);
_f2 = vaddq_f32(_f2, _c2);
_f3 = vaddq_f32(_f3, _c3);
_f4 = vaddq_f32(_f4, _c4);
_f5 = vaddq_f32(_f5, _c5);
_f6 = vaddq_f32(_f6, _c6);
_f7 = vaddq_f32(_f7, _c7);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f0 = vmlaq_f32(_f0, _c0, _beta);
_f1 = vmlaq_f32(_f1, _c2, _beta);
_f2 = vmlaq_f32(_f2, _c4, _beta);
_f3 = vmlaq_f32(_f3, _c6, _beta);
_f4 = vmlaq_f32(_f4, _c8, _beta);
_f5 = vmlaq_f32(_f5, _ca, _beta);
_f6 = vmlaq_f32(_f6, _cc, _beta);
_f7 = vmlaq_f32(_f7, _ce, _beta);
_f8 = vmlaq_f32(_f8, _c1, _beta);
_f9 = vmlaq_f32(_f9, _c3, _beta);
_fa = vmlaq_f32(_fa, _c5, _beta);
_fb = vmlaq_f32(_fb, _c7, _beta);
_fc = vmlaq_f32(_fc, _c9, _beta);
_fd = vmlaq_f32(_fd, _cb, _beta);
_fe = vmlaq_f32(_fe, _cd, _beta);
_ff = vmlaq_f32(_ff, _cf, _beta);
_f1 = vmlaq_f32(_f1, _c1, _beta);
_f2 = vmlaq_f32(_f2, _c2, _beta);
_f3 = vmlaq_f32(_f3, _c3, _beta);
_f4 = vmlaq_f32(_f4, _c4, _beta);
_f5 = vmlaq_f32(_f5, _c5, _beta);
_f6 = vmlaq_f32(_f6, _c6, _beta);
_f7 = vmlaq_f32(_f7, _c7, _beta);
}
_c0 = vld1q_f32(pC + c_hstep * 4);
_c1 = vld1q_f32(pC + c_hstep * 4 + 4);
_c2 = vld1q_f32(pC + c_hstep * 5);
_c3 = vld1q_f32(pC + c_hstep * 5 + 4);
_c4 = vld1q_f32(pC + c_hstep * 6);
_c5 = vld1q_f32(pC + c_hstep * 6 + 4);
_c6 = vld1q_f32(pC + c_hstep * 7);
_c7 = vld1q_f32(pC + c_hstep * 7 + 4);
transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7);
if (beta == 1.f)
{
_f8 = vaddq_f32(_f8, _c0);
_f9 = vaddq_f32(_f9, _c1);
_fa = vaddq_f32(_fa, _c2);
_fb = vaddq_f32(_fb, _c3);
_fc = vaddq_f32(_fc, _c4);
_fd = vaddq_f32(_fd, _c5);
_fe = vaddq_f32(_fe, _c6);
_ff = vaddq_f32(_ff, _c7);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f8 = vmlaq_f32(_f8, _c0, _beta);
_f9 = vmlaq_f32(_f9, _c1, _beta);
_fa = vmlaq_f32(_fa, _c2, _beta);
_fb = vmlaq_f32(_fb, _c3, _beta);
_fc = vmlaq_f32(_fc, _c4, _beta);
_fd = vmlaq_f32(_fd, _c5, _beta);
_fe = vmlaq_f32(_fe, _c6, _beta);
_ff = vmlaq_f32(_ff, _c7, _beta);
}
pC += 8;
}
Expand Down Expand Up @@ -10451,38 +10458,45 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
{
if (c_elempack == 1)
{
// TODO decompose 4x8 to 4x4 and 4x4
_c0 = vld1q_f32(pC);
_c1 = vld1q_f32(pC + c_hstep);
float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2);
float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3);
float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4);
float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5);
float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6);
float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7);
transpose4x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7);
transpose4x4_ps(_c0, _c1, _c2, _c3);
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c2);
_f2 = vaddq_f32(_f2, _c4);
_f3 = vaddq_f32(_f3, _c6);
_f4 = vaddq_f32(_f4, _c1);
_f5 = vaddq_f32(_f5, _c3);
_f6 = vaddq_f32(_f6, _c5);
_f7 = vaddq_f32(_f7, _c7);
_f1 = vaddq_f32(_f1, _c1);
_f2 = vaddq_f32(_f2, _c2);
_f3 = vaddq_f32(_f3, _c3);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f0 = vmlaq_f32(_f0, _c0, _beta);
_f1 = vmlaq_f32(_f1, _c2, _beta);
_f2 = vmlaq_f32(_f2, _c4, _beta);
_f3 = vmlaq_f32(_f3, _c6, _beta);
_f4 = vmlaq_f32(_f4, _c1, _beta);
_f5 = vmlaq_f32(_f5, _c3, _beta);
_f6 = vmlaq_f32(_f6, _c5, _beta);
_f7 = vmlaq_f32(_f7, _c7, _beta);
_f1 = vmlaq_f32(_f1, _c1, _beta);
_f2 = vmlaq_f32(_f2, _c2, _beta);
_f3 = vmlaq_f32(_f3, _c3, _beta);
}
_c0 = vld1q_f32(pC + c_hstep * 4);
_c1 = vld1q_f32(pC + c_hstep * 5);
_c2 = vld1q_f32(pC + c_hstep * 6);
_c3 = vld1q_f32(pC + c_hstep * 7);
transpose4x4_ps(_c0, _c1, _c2, _c3);
if (beta == 1.f)
{
_f4 = vaddq_f32(_f4, _c0);
_f5 = vaddq_f32(_f5, _c1);
_f6 = vaddq_f32(_f6, _c2);
_f7 = vaddq_f32(_f7, _c3);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f4 = vmlaq_f32(_f4, _c0, _beta);
_f5 = vmlaq_f32(_f5, _c1, _beta);
_f6 = vmlaq_f32(_f6, _c2, _beta);
_f7 = vmlaq_f32(_f7, _c3, _beta);
}
pC += 4;
}
Expand Down
Loading

0 comments on commit dcd0636

Please sign in to comment.