Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 13, 2023
1 parent d0e46fd commit 1375a03
Showing 1 changed file with 128 additions and 18 deletions.
146 changes: 128 additions & 18 deletions src/layer/arm/convolution_im2col_gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -7302,11 +7302,24 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo
if (elempack == 8)
{
const signed char* p0 = (const signed char*)bottom_blob.channel(k / 8) + (j + jj) * 8;
const int cstep = bottom_blob.cstep * 8;

int kk = 0;
#if __ARM_FEATURE_MATMUL_INT8
for (; kk < max_kk / 8; kk++)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], %4 \n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1", "v2", "v3");
#else // NCNN_GNU_INLINE_ASM
int8x16_t _r01 = vld1q_s8(p0);
int8x16_t _r23 = vld1q_s8(p0 + 16);
int8x16_t _r45 = vld1q_s8(p0 + 32);
Expand All @@ -7316,51 +7329,113 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo
vst1q_s8(pp + 32, _r45);
vst1q_s8(pp + 48, _r67);
pp += 64;
p0 += bottom_blob.cstep * 8;
p0 += cstep;
#endif // NCNN_GNU_INLINE_ASM
}
#elif __ARM_FEATURE_DOTPROD
for (; kk < max_kk / 8; kk++)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], %4 \n"
"uzp1 v4.4s, v0.4s, v1.4s \n"
"uzp2 v6.4s, v0.4s, v1.4s \n"
"uzp1 v5.4s, v2.4s, v3.4s \n"
"uzp2 v7.4s, v2.4s, v3.4s \n"
"st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%1], #64 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else // NCNN_GNU_INLINE_ASM
int32x4x2_t _r0246 = vld2q_s32((const int*)p0);
int32x4x2_t _r1357 = vld2q_s32((const int*)(p0 + 32));
vst1q_s32((int*)pp, _r0246.val[0]);
vst1q_s32((int*)(pp + 16), _r1357.val[0]);
vst1q_s32((int*)(pp + 32), _r0246.val[1]);
vst1q_s32((int*)(pp + 48), _r1357.val[1]);
pp += 64;
p0 += bottom_blob.cstep * 8;
p0 += cstep;
#endif // NCNN_GNU_INLINE_ASM
}
#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD
for (; kk < max_kk / 8; kk++)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], %4 \n"
"st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1", "v2", "v3");
#else // NCNN_GNU_INLINE_ASM
int16x8x4_t _r0 = vld4q_s16((const short*)p0);
vst1q_s16((short*)pp, _r0.val[0]);
vst1q_s16((short*)(pp + 16), _r0.val[1]);
vst1q_s16((short*)(pp + 32), _r0.val[2]);
vst1q_s16((short*)(pp + 48), _r0.val[3]);
pp += 64;
p0 += bottom_blob.cstep * 8;
p0 += cstep;
#endif // NCNN_GNU_INLINE_ASM
}
#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD
}

if (elempack == 1)
{
const signed char* p0 = (const signed char*)bottom_blob.channel(k) + (j + jj);
const int cstep = bottom_blob.cstep;

int kk = 0;
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_MATMUL_INT8
for (; kk + 7 < max_kk; kk += 8)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v0.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v1.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v0.d}[1], [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v1.d}[1], [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v2.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v3.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v2.d}[1], [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v3.d}[1], [%0], %4 \n"
"zip1 v4.16b, v0.16b, v1.16b \n"
"zip2 v5.16b, v0.16b, v1.16b \n"
"zip1 v6.16b, v2.16b, v3.16b \n"
"zip2 v7.16b, v2.16b, v3.16b \n"
"st4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%1], #64 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else // NCNN_GNU_INLINE_ASM
int8x8_t _r0 = vld1_s8(p0);
int8x8_t _r1 = vld1_s8(p0 + bottom_blob.cstep);
int8x8_t _r2 = vld1_s8(p0 + bottom_blob.cstep * 2);
int8x8_t _r3 = vld1_s8(p0 + bottom_blob.cstep * 3);
int8x8_t _r4 = vld1_s8(p0 + bottom_blob.cstep * 4);
int8x8_t _r5 = vld1_s8(p0 + bottom_blob.cstep * 5);
int8x8_t _r6 = vld1_s8(p0 + bottom_blob.cstep * 6);
int8x8_t _r7 = vld1_s8(p0 + bottom_blob.cstep * 7);
int8x8_t _r1 = vld1_s8(p0 + cstep);
int8x8_t _r2 = vld1_s8(p0 + cstep * 2);
int8x8_t _r3 = vld1_s8(p0 + cstep * 3);
int8x8_t _r4 = vld1_s8(p0 + cstep * 4);
int8x8_t _r5 = vld1_s8(p0 + cstep * 5);
int8x8_t _r6 = vld1_s8(p0 + cstep * 6);
int8x8_t _r7 = vld1_s8(p0 + cstep * 7);
// save as transpose8x8
int8x8x2_t _r01 = vzip_s8(_r0, _r1);
int8x8x2_t _r23 = vzip_s8(_r2, _r3);
Expand All @@ -7373,35 +7448,70 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo
_r0246.val[3] = vreinterpretq_s16_s8(vcombine_s8(_r67.val[0], _r67.val[1]));
vst4q_s16((short*)pp, _r0246);
pp += 64;
p0 += bottom_blob.cstep * 8;
p0 += cstep * 8;
#endif // NCNN_GNU_INLINE_ASM
}
#endif // __ARM_FEATURE_MATMUL_INT8
for (; kk + 3 < max_kk; kk += 4)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v0.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v1.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v2.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v3.8b}, [%0], %4 \n"
"st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [%1], #32 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1", "v2", "v3");
#else // NCNN_GNU_INLINE_ASM
int8x8x4_t _r01;
_r01.val[0] = vld1_s8(p0);
_r01.val[1] = vld1_s8(p0 + bottom_blob.cstep);
_r01.val[2] = vld1_s8(p0 + bottom_blob.cstep * 2);
_r01.val[3] = vld1_s8(p0 + bottom_blob.cstep * 3);
_r01.val[1] = vld1_s8(p0 + cstep);
_r01.val[2] = vld1_s8(p0 + cstep * 2);
_r01.val[3] = vld1_s8(p0 + cstep * 3);
vst4_s8(pp, _r01);
pp += 32;
p0 += bottom_blob.cstep * 4;
p0 += cstep * 4;
#endif // NCNN_GNU_INLINE_ASM
}
#endif // __ARM_FEATURE_DOTPROD
for (; kk + 1 < max_kk; kk += 2)
{
#if NCNN_GNU_INLINE_ASM
asm volatile(
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v0.8b}, [%0], %4 \n"
"prfm pldl1keep, [%0, #64] \n"
"ld1 {v1.8b}, [%0], %4 \n"
"st2 {v0.8b, v1.8b}, [%1], #16 \n"
: "=r"(p0), // %0
"=r"(pp) // %1
: "0"(p0),
"1"(pp),
"r"(cstep)
: "memory", "v0", "v1");
#else // NCNN_GNU_INLINE_ASM
int8x8x2_t _r01;
_r01.val[0] = vld1_s8(p0);
_r01.val[1] = vld1_s8(p0 + bottom_blob.cstep);
_r01.val[1] = vld1_s8(p0 + cstep);
vst2_s8(pp, _r01);
pp += 16;
p0 += bottom_blob.cstep * 2;
p0 += cstep * 2;
#endif // NCNN_GNU_INLINE_ASM
}
for (; kk < max_kk; kk++)
{
vst1_s8(pp, vld1_s8(p0));
pp += 8;
p0 += bottom_blob.cstep;
p0 += cstep;
}
}
}
Expand Down

0 comments on commit 1375a03

Please sign in to comment.