From 9c3057bb02f6dbe0bc9d8040c9d6394d47a6ea74 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 10 Oct 2024 15:30:47 +0800 Subject: [PATCH] fix test, test++ --- src/layer/arm/gemm_int8_fp16s.h | 48 ++++++++++++++++----------------- tests/test_gemm_3.cpp | 7 +++++ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 34dde2629de..05dd8975da0 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -2275,14 +2275,14 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int8x8_t _r0 = float2int8(_p0, _p1); - int8x8_t _r1 = float2int8(_p2, _p3); - int8x8_t _r2 = float2int8(_p4, _p5); - int8x8_t _r3 = float2int8(_p6, _p7); - int8x8_t _r4 = float2int8(_p8, _p9); - int8x8_t _r5 = float2int8(_pa, _pb); - int8x8_t _r6 = float2int8(_pc, _pd); - int8x8_t _r7 = float2int8(_pe, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); @@ -2720,10 +2720,10 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); - int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5)); - int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6)); - int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7)); + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); int16x4x2_t _t23 = vuzp_s16(_t2, _t3); int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); @@ -4678,14 +4678,14 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int8x8_t _r0 = float2int8(_p0, _p1); - int8x8_t _r1 = float2int8(_p2, _p3); - int8x8_t _r2 = float2int8(_p4, _p5); - int8x8_t _r3 = float2int8(_p6, _p7); - int8x8_t _r4 = float2int8(_p8, _p9); - int8x8_t _r5 = float2int8(_pa, _pb); - int8x8_t _r6 = float2int8(_pc, _pd); - int8x8_t _r7 = float2int8(_pe, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); @@ -5080,10 +5080,10 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); - int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5)); - int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6)); - int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7)); + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); int16x4x2_t _t23 = vuzp_s16(_t2, _t3); int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 5348f8d674e..f753a6594fa 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -243,6 +243,13 @@ int main() int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); if (ret != 0) return ret; + + if (M != N) + { + int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K); + if (ret != 0) + return ret; + } } #else // test nothing for non-int8 build