Skip to content

Commit

Permalink
fix test, test++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 10, 2024
1 parent 73e7364 commit 9c3057b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/layer/arm/gemm_int8_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -2275,14 +2275,14 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int
vst1q_s8(pp + 48, vcombine_s8(_r6, _r7));
#endif // __ARM_FEATURE_MATMUL_INT8
#else // __ARM_FEATURE_DOTPROD
int8x8_t _r0 = float2int8(_p0, _p1);
int8x8_t _r1 = float2int8(_p2, _p3);
int8x8_t _r2 = float2int8(_p4, _p5);
int8x8_t _r3 = float2int8(_p6, _p7);
int8x8_t _r4 = float2int8(_p8, _p9);
int8x8_t _r5 = float2int8(_pa, _pb);
int8x8_t _r6 = float2int8(_pc, _pd);
int8x8_t _r7 = float2int8(_pe, _pf);
int8x8_t _r0 = float2int8(_p0, _p2);
int8x8_t _r1 = float2int8(_p4, _p6);
int8x8_t _r2 = float2int8(_p8, _pa);
int8x8_t _r3 = float2int8(_pc, _pe);
int8x8_t _r4 = float2int8(_p1, _p3);
int8x8_t _r5 = float2int8(_p5, _p7);
int8x8_t _r6 = float2int8(_p9, _pb);
int8x8_t _r7 = float2int8(_pd, _pf);

int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1));
int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3));
Expand Down Expand Up @@ -2720,10 +2720,10 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int
int8x8_t _r3 = float2int8(_p5, _p7);
#endif // __ARM_FEATURE_MATMUL_INT8
#else // __ARM_FEATURE_DOTPROD
int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4));
int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5));
int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6));
int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7));
int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2));
int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6));
int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3));
int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7));
int16x4x2_t _t01 = vuzp_s16(_t0, _t1);
int16x4x2_t _t23 = vuzp_s16(_t2, _t3);
int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]);
Expand Down Expand Up @@ -4678,14 +4678,14 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int
vst1q_s8(pp + 48, vcombine_s8(_r6, _r7));
#endif // __ARM_FEATURE_MATMUL_INT8
#else // __ARM_FEATURE_DOTPROD
int8x8_t _r0 = float2int8(_p0, _p1);
int8x8_t _r1 = float2int8(_p2, _p3);
int8x8_t _r2 = float2int8(_p4, _p5);
int8x8_t _r3 = float2int8(_p6, _p7);
int8x8_t _r4 = float2int8(_p8, _p9);
int8x8_t _r5 = float2int8(_pa, _pb);
int8x8_t _r6 = float2int8(_pc, _pd);
int8x8_t _r7 = float2int8(_pe, _pf);
int8x8_t _r0 = float2int8(_p0, _p2);
int8x8_t _r1 = float2int8(_p4, _p6);
int8x8_t _r2 = float2int8(_p8, _pa);
int8x8_t _r3 = float2int8(_pc, _pe);
int8x8_t _r4 = float2int8(_p1, _p3);
int8x8_t _r5 = float2int8(_p5, _p7);
int8x8_t _r6 = float2int8(_p9, _pb);
int8x8_t _r7 = float2int8(_pd, _pf);

int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1));
int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3));
Expand Down Expand Up @@ -5080,10 +5080,10 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int
int8x8_t _r3 = float2int8(_p5, _p7);
#endif // __ARM_FEATURE_MATMUL_INT8
#else // __ARM_FEATURE_DOTPROD
int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4));
int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5));
int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6));
int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7));
int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2));
int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6));
int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3));
int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7));
int16x4x2_t _t01 = vuzp_s16(_t0, _t1);
int16x4x2_t _t23 = vuzp_s16(_t2, _t3);
int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]);
Expand Down
7 changes: 7 additions & 0 deletions tests/test_gemm_3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ int main()
int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K);
if (ret != 0)
return ret;

if (M != N)
{
int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K);
if (ret != 0)
return ret;
}
}
#else
// test nothing for non-int8 build
Expand Down

0 comments on commit 9c3057b

Please sign in to comment.