From 2064a2563ce1e25e00202a8c62dde49c1735289f Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 3 Sep 2024 16:23:39 +0800 Subject: [PATCH] reuse rms for sqsum and affine --- src/layer/arm/rmsnorm_arm.cpp | 88 ++++++++++++-------------- src/layer/arm/rmsnorm_arm_asimdhp.cpp | 89 +++++++++++++-------------- 2 files changed, 82 insertions(+), 95 deletions(-) diff --git a/src/layer/arm/rmsnorm_arm.cpp b/src/layer/arm/rmsnorm_arm.cpp index 18b363387b5..e19136ca29d 100644 --- a/src/layer/arm/rmsnorm_arm.cpp +++ b/src/layer/arm/rmsnorm_arm.cpp @@ -42,9 +42,9 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount const int size = elemcount * elempack; #if __ARM_NEON - float32x4_t _sqsum = vdupq_n_f32(0.f); + float32x4_t _rms = vdupq_n_f32(0.f); #endif // __ARM_NEON - float sqsum = 0.f; + float rms = 0.f; { const float* ptr0 = ptr; @@ -53,56 +53,53 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr0); - _sqsum = vmlaq_f32(_sqsum, _p, _p); + _rms = vmlaq_f32(_rms, _p, _p); ptr0 += 4; } #endif // __ARM_NEON for (; i < size; i++) { - sqsum += ptr0[0] * ptr0[0]; + rms += ptr0[0] * ptr0[0]; ptr0++; } } #if __ARM_NEON - float32x4_t _a; if (elempack == 4) { float32x4_t _elemcount = vdupq_n_f32(elemcount); float32x4_t _eps = vdupq_n_f32(eps); #if __aarch64__ - _sqsum = vdivq_f32(_sqsum, _elemcount); - _sqsum = vaddq_f32(_sqsum, _eps); + _rms = vdivq_f32(_rms, _elemcount); + _rms = vaddq_f32(_rms, _eps); #else float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount); _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); - _sqsum = vmlaq_f32(_eps, _sqsum, _inv_elemcount); + _rms = vmlaq_f32(_eps, _rms, _inv_elemcount); #endif - _a = vrsqrteq_f32(_sqsum); - _a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a); - _a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a); + float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms); + _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); } #endif // __ARM_NEON - - float a; if (elempack == 1) { #if __ARM_NEON #if __aarch64__ - sqsum += vaddvq_f32(_sqsum); + rms += vaddvq_f32(_rms); #else - float32x2_t _s2 = vadd_f32(vget_low_f32(_sqsum), vget_high_f32(_sqsum)); + float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms)); _s2 = vpadd_f32(_s2, _s2); - sqsum += vget_lane_f32(_s2, 0); + rms += vget_lane_f32(_s2, 0); #endif #endif // __ARM_NEON - a = 1.f / sqrtf(sqsum / elemcount + eps); + rms = 1.f / sqrtf(rms / elemcount + eps); #if __ARM_NEON - _a = vdupq_n_f32(a); + _rms = vdupq_n_f32(rms); #endif // __ARM_NEON } @@ -116,21 +113,20 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount { float32x4_t _p = vld1q_f32(ptr); float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); _p = vmulq_f32(_p, _gamma); vst1q_f32(ptr, _p); ptr += 4; gamma_ptr += 1; } } - if (elempack == 1) { for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr); float32x4_t _gamma = vld1q_f32(gamma_ptr); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); _p = vmulq_f32(_p, _gamma); vst1q_f32(ptr, _p); ptr += 4; @@ -140,7 +136,7 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount #endif // __ARM_NEON for (; i < size; i++) { - ptr[0] = (ptr[0] * a) * gamma_ptr[0]; + ptr[0] = (ptr[0] * rms) * gamma_ptr[0]; ptr++; gamma_ptr++; } @@ -152,14 +148,14 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); vst1q_f32(ptr, _p); ptr += 4; } #endif // __ARM_NEON for (; i < size; i++) { - ptr[0] = ptr[0] * a; + ptr[0] = ptr[0] * rms; ptr++; } } @@ -239,9 +235,9 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps const int size = elemcount * elempack; #if __ARM_NEON - float32x4_t _sqsum = vdupq_n_f32(0.f); + float32x4_t _rms = vdupq_n_f32(0.f); #endif // __ARM_NEON - float sqsum = 0.f; + float rms = 0.f; { const unsigned short* ptr0 = ptr; @@ -250,57 +246,54 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps for (; i + 3 < size; i += 4) { float32x4_t _p = bfloat2float(vld1_u16(ptr0)); - _sqsum = vmlaq_f32(_sqsum, _p, _p); + _rms = vmlaq_f32(_rms, _p, _p); ptr0 += 4; } #endif // __ARM_NEON for (; i < size; i++) { float v = bfloat16_to_float32(ptr0[0]); - sqsum += v * v; + rms += v * v; ptr0++; } } #if __ARM_NEON - float32x4_t _a; if (elempack == 4) { float32x4_t _elemcount = vdupq_n_f32(elemcount); float32x4_t _eps = vdupq_n_f32(eps); #if __aarch64__ - _sqsum = vdivq_f32(_sqsum, _elemcount); - _sqsum = vaddq_f32(_sqsum, _eps); + _rms = vdivq_f32(_rms, _elemcount); + _rms = vaddq_f32(_rms, _eps); #else float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount); _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); - _sqsum = vmlaq_f32(_eps, _sqsum, _inv_elemcount); + _rms = vmlaq_f32(_eps, _rms, _inv_elemcount); #endif - _a = vrsqrteq_f32(_sqsum); - _a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a); - _a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a); + float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms); + _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); } #endif // __ARM_NEON - - float a; if (elempack == 1) { #if __ARM_NEON #if __aarch64__ - sqsum += vaddvq_f32(_sqsum); + rms += vaddvq_f32(_rms); #else - float32x2_t _s2 = vadd_f32(vget_low_f32(_sqsum), vget_high_f32(_sqsum)); + float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms)); _s2 = vpadd_f32(_s2, _s2); - sqsum += vget_lane_f32(_s2, 0); + rms += vget_lane_f32(_s2, 0); #endif #endif // __ARM_NEON - a = 1.f / sqrtf(sqsum / elemcount + eps); + rms = 1.f / sqrtf(rms / elemcount + eps); #if __ARM_NEON - _a = vdupq_n_f32(a); + _rms = vdupq_n_f32(rms); #endif // __ARM_NEON } @@ -314,21 +307,20 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps { float32x4_t _p = bfloat2float(vld1_u16(ptr)); float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); _p = vmulq_f32(_p, _gamma); vst1_u16(ptr, float2bfloat(_p)); ptr += 4; gamma_ptr += 1; } } - if (elempack == 1) { for (; i + 3 < size; i += 4) { float32x4_t _p = bfloat2float(vld1_u16(ptr)); float32x4_t _gamma = vld1q_f32(gamma_ptr); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); _p = vmulq_f32(_p, _gamma); vst1_u16(ptr, float2bfloat(_p)); ptr += 4; @@ -339,7 +331,7 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps for (; i < size; i++) { float v = bfloat16_to_float32(ptr[0]); - ptr[0] = float32_to_bfloat16((v * a) * gamma_ptr[0]); + ptr[0] = float32_to_bfloat16((v * rms) * gamma_ptr[0]); ptr++; gamma_ptr++; } @@ -351,7 +343,7 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps for (; i + 3 < size; i += 4) { float32x4_t _p = bfloat2float(vld1_u16(ptr)); - _p = vmulq_f32(_p, _a); + _p = vmulq_f32(_p, _rms); vst1_u16(ptr, float2bfloat(_p)); ptr += 4; } @@ -359,7 +351,7 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps for (; i < size; i++) { float v = bfloat16_to_float32(ptr[0]); - ptr[0] = float32_to_bfloat16(v * a); + ptr[0] = float32_to_bfloat16(v * rms); ptr++; } } diff --git a/src/layer/arm/rmsnorm_arm_asimdhp.cpp b/src/layer/arm/rmsnorm_arm_asimdhp.cpp index 74fcb7d7a37..98d8e696487 100644 --- a/src/layer/arm/rmsnorm_arm_asimdhp.cpp +++ b/src/layer/arm/rmsnorm_arm_asimdhp.cpp @@ -26,9 +26,9 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el { const int size = elemcount * elempack; - float32x4_t _sqsum0 = vdupq_n_f32(0.f); - float32x4_t _sqsum1 = vdupq_n_f32(0.f); - float sqsum = 0.f; + float32x4_t _rms0 = vdupq_n_f32(0.f); + float32x4_t _rms1 = vdupq_n_f32(0.f); + float rms = 0.f; { const __fp16* ptr0 = ptr; @@ -38,66 +38,63 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el float16x8_t _p = vld1q_f16(ptr0); float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); - _sqsum0 = vmlaq_f32(_sqsum0, _p0, _p0); - _sqsum1 = vmlaq_f32(_sqsum1, _p1, _p1); + _rms0 = vmlaq_f32(_rms0, _p0, _p0); + _rms1 = vmlaq_f32(_rms1, _p1, _p1); ptr0 += 8; } for (; i + 3 < size; i += 4) { float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0)); - _sqsum0 = vmlaq_f32(_sqsum0, _p, _p); + _rms0 = vmlaq_f32(_rms0, _p, _p); ptr0 += 4; } for (; i < size; i++) { - sqsum += (float)ptr0[0] * (float)ptr0[0]; + rms += (float)ptr0[0] * (float)ptr0[0]; ptr0++; } } - float32x4_t _a0; - float32x4_t _a1; - float a; if (elempack == 8) { float32x4_t _elemcount = vdupq_n_f32(elemcount); float32x4_t _eps = vdupq_n_f32(eps); - _sqsum0 = vdivq_f32(_sqsum0, _elemcount); - _sqsum1 = vdivq_f32(_sqsum1, _elemcount); - _sqsum0 = vaddq_f32(_sqsum0, _eps); - _sqsum1 = vaddq_f32(_sqsum1, _eps); + _rms0 = vdivq_f32(_rms0, _elemcount); + _rms1 = vdivq_f32(_rms1, _elemcount); + _rms0 = vaddq_f32(_rms0, _eps); + _rms1 = vaddq_f32(_rms1, _eps); - _a0 = vrsqrteq_f32(_sqsum0); - _a1 = vrsqrteq_f32(_sqsum1); - _a0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum0, _a0), _a0), _a0); - _a1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum1, _a1), _a1), _a1); - _a0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum0, _a0), _a0), _a0); - _a1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum1, _a1), _a1), _a1); + float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0); + float32x4_t _rsqrt_rms1 = vrsqrteq_f32(_rms1); + _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rsqrt_rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1); + _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1); } if (elempack == 4) { - _sqsum0 = vaddq_f32(_sqsum0, _sqsum1); + _rms0 = vaddq_f32(_rms0, _rms1); float32x4_t _elemcount = vdupq_n_f32(elemcount); float32x4_t _eps = vdupq_n_f32(eps); - _sqsum0 = vdivq_f32(_sqsum0, _elemcount); - _sqsum0 = vaddq_f32(_sqsum0, _eps); + _rms0 = vdivq_f32(_rms0, _elemcount); + _rms0 = vaddq_f32(_rms0, _eps); - _a0 = vrsqrteq_f32(_sqsum0); - _a0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum0, _a0), _a0), _a0); - _a0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum0, _a0), _a0), _a0); - _a1 = _a0; + float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0); + _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms1 = _rms0; } if (elempack == 1) { - _sqsum0 = vaddq_f32(_sqsum0, _sqsum1); - sqsum += vaddvq_f32(_sqsum0); + _rms0 = vaddq_f32(_rms0, _rms1); + rms += vaddvq_f32(_rms0); - a = 1.f / sqrtf(sqsum / elemcount + eps); - _a0 = vdupq_n_f32(a); - _a1 = _a0; + rms = 1.f / sqrtf(rms / elemcount + eps); + _rms0 = vdupq_n_f32(rms); + _rms1 = _rms0; } if (gamma_ptr) @@ -111,8 +108,8 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); - _p0 = vmulq_f32(_p0, _a0); - _p1 = vmulq_f32(_p1, _a1); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); _p0 = vmulq_f32(_p0, _gamma); _p1 = vmulq_f32(_p1, _gamma); _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); @@ -121,7 +118,6 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el gamma_ptr += 1; } } - if (elempack == 4) { for (; i + 7 < size; i += 8) @@ -131,8 +127,8 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); float32x4_t _gamma0 = vdupq_n_f32(gamma_ptr[0]); float32x4_t _gamma1 = vdupq_n_f32(gamma_ptr[1]); - _p0 = vmulq_f32(_p0, _a0); - _p1 = vmulq_f32(_p1, _a1); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); _p0 = vmulq_f32(_p0, _gamma0); _p1 = vmulq_f32(_p1, _gamma1); _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); @@ -144,14 +140,13 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el { float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); - _p = vmulq_f32(_p, _a0); + _p = vmulq_f32(_p, _rms0); _p = vmulq_f32(_p, _gamma); vst1_f16(ptr, vcvt_f16_f32(_p)); ptr += 4; gamma_ptr += 1; } } - if (elempack == 1) { for (; i + 7 < size; i += 8) @@ -161,8 +156,8 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); float32x4_t _gamma0 = vld1q_f32(gamma_ptr); float32x4_t _gamma1 = vld1q_f32(gamma_ptr + 4); - _p0 = vmulq_f32(_p0, _a0); - _p1 = vmulq_f32(_p1, _a1); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); _p0 = vmulq_f32(_p0, _gamma0); _p1 = vmulq_f32(_p1, _gamma1); _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); @@ -174,7 +169,7 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el { float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); float32x4_t _gamma = vld1q_f32(gamma_ptr); - _p = vmulq_f32(_p, _a0); + _p = vmulq_f32(_p, _rms0); _p = vmulq_f32(_p, _gamma); vst1_f16(ptr, vcvt_f16_f32(_p)); ptr += 4; @@ -183,7 +178,7 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el } for (; i < size; i++) { - ptr[0] = (__fp16)(((float)ptr[0] * a) * gamma_ptr[0]); + ptr[0] = (__fp16)(((float)ptr[0] * rms) * gamma_ptr[0]); ptr++; gamma_ptr++; } @@ -196,8 +191,8 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el float16x8_t _p = vld1q_f16(ptr); float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); - _p0 = vmulq_f32(_p0, _a0); - _p1 = vmulq_f32(_p1, _a1); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); vst1q_f16(ptr, _p); ptr += 8; @@ -205,13 +200,13 @@ static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int el for (; i + 3 < size; i += 4) { float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); - _p = vmulq_f32(_p, _a0); + _p = vmulq_f32(_p, _rms0); vst1_f16(ptr, vcvt_f16_f32(_p)); ptr += 4; } for (; i < size; i++) { - ptr[0] = (__fp16)((float)ptr[0] * a); + ptr[0] = (__fp16)((float)ptr[0] * rms); ptr++; } }