Skip to content

Commit

Permalink
reuse rms for sqsum and affine
Browse files — browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 3, 2024
1 parent e778300 commit 2064a25
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 95 deletions.
88 changes: 40 additions & 48 deletions src/layer/arm/rmsnorm_arm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,9 +42,9 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
const int size = elemcount * elempack;

#if __ARM_NEON
float32x4_t _sqsum = vdupq_n_f32(0.f);
float32x4_t _rms = vdupq_n_f32(0.f);
#endif // __ARM_NEON
float sqsum = 0.f;
float rms = 0.f;
{
const float* ptr0 = ptr;

Expand All @@ -53,56 +53,53 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
for (; i + 3 < size; i += 4)
{
float32x4_t _p = vld1q_f32(ptr0);
_sqsum = vmlaq_f32(_sqsum, _p, _p);
_rms = vmlaq_f32(_rms, _p, _p);
ptr0 += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
sqsum += ptr0[0] * ptr0[0];
rms += ptr0[0] * ptr0[0];
ptr0++;
}
}

#if __ARM_NEON
float32x4_t _a;
if (elempack == 4)
{
float32x4_t _elemcount = vdupq_n_f32(elemcount);
float32x4_t _eps = vdupq_n_f32(eps);

#if __aarch64__
_sqsum = vdivq_f32(_sqsum, _elemcount);
_sqsum = vaddq_f32(_sqsum, _eps);
_rms = vdivq_f32(_rms, _elemcount);
_rms = vaddq_f32(_rms, _eps);
#else
float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_sqsum = vmlaq_f32(_eps, _sqsum, _inv_elemcount);
_rms = vmlaq_f32(_eps, _rms, _inv_elemcount);
#endif

_a = vrsqrteq_f32(_sqsum);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms);
_rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
}
#endif // __ARM_NEON

float a;
if (elempack == 1)
{
#if __ARM_NEON
#if __aarch64__
sqsum += vaddvq_f32(_sqsum);
rms += vaddvq_f32(_rms);
#else
float32x2_t _s2 = vadd_f32(vget_low_f32(_sqsum), vget_high_f32(_sqsum));
float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms));
_s2 = vpadd_f32(_s2, _s2);
sqsum += vget_lane_f32(_s2, 0);
rms += vget_lane_f32(_s2, 0);
#endif
#endif // __ARM_NEON

a = 1.f / sqrtf(sqsum / elemcount + eps);
rms = 1.f / sqrtf(rms / elemcount + eps);
#if __ARM_NEON
_a = vdupq_n_f32(a);
_rms = vdupq_n_f32(rms);
#endif // __ARM_NEON
}

Expand All @@ -116,21 +113,20 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
_p = vmulq_f32(_p, _gamma);
vst1q_f32(ptr, _p);
ptr += 4;
gamma_ptr += 1;
}
}

if (elempack == 1)
{
for (; i + 3 < size; i += 4)
{
float32x4_t _p = vld1q_f32(ptr);
float32x4_t _gamma = vld1q_f32(gamma_ptr);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
_p = vmulq_f32(_p, _gamma);
vst1q_f32(ptr, _p);
ptr += 4;
Expand All @@ -140,7 +136,7 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
#endif // __ARM_NEON
for (; i < size; i++)
{
ptr[0] = (ptr[0] * a) * gamma_ptr[0];
ptr[0] = (ptr[0] * rms) * gamma_ptr[0];
ptr++;
gamma_ptr++;
}
Expand All @@ -152,14 +148,14 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
for (; i + 3 < size; i += 4)
{
float32x4_t _p = vld1q_f32(ptr);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
vst1q_f32(ptr, _p);
ptr += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
ptr[0] = ptr[0] * a;
ptr[0] = ptr[0] * rms;
ptr++;
}
}
Expand Down Expand Up @@ -239,9 +235,9 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps
const int size = elemcount * elempack;

#if __ARM_NEON
float32x4_t _sqsum = vdupq_n_f32(0.f);
float32x4_t _rms = vdupq_n_f32(0.f);
#endif // __ARM_NEON
float sqsum = 0.f;
float rms = 0.f;
{
const unsigned short* ptr0 = ptr;

Expand All @@ -250,57 +246,54 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr0));
_sqsum = vmlaq_f32(_sqsum, _p, _p);
_rms = vmlaq_f32(_rms, _p, _p);
ptr0 += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr0[0]);
sqsum += v * v;
rms += v * v;
ptr0++;
}
}

#if __ARM_NEON
float32x4_t _a;
if (elempack == 4)
{
float32x4_t _elemcount = vdupq_n_f32(elemcount);
float32x4_t _eps = vdupq_n_f32(eps);

#if __aarch64__
_sqsum = vdivq_f32(_sqsum, _elemcount);
_sqsum = vaddq_f32(_sqsum, _eps);
_rms = vdivq_f32(_rms, _elemcount);
_rms = vaddq_f32(_rms, _eps);
#else
float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_sqsum = vmlaq_f32(_eps, _sqsum, _inv_elemcount);
_rms = vmlaq_f32(_eps, _rms, _inv_elemcount);
#endif

_a = vrsqrteq_f32(_sqsum);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms);
_rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
}
#endif // __ARM_NEON

float a;
if (elempack == 1)
{
#if __ARM_NEON
#if __aarch64__
sqsum += vaddvq_f32(_sqsum);
rms += vaddvq_f32(_rms);
#else
float32x2_t _s2 = vadd_f32(vget_low_f32(_sqsum), vget_high_f32(_sqsum));
float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms));
_s2 = vpadd_f32(_s2, _s2);
sqsum += vget_lane_f32(_s2, 0);
rms += vget_lane_f32(_s2, 0);
#endif
#endif // __ARM_NEON

a = 1.f / sqrtf(sqsum / elemcount + eps);
rms = 1.f / sqrtf(rms / elemcount + eps);
#if __ARM_NEON
_a = vdupq_n_f32(a);
_rms = vdupq_n_f32(rms);
#endif // __ARM_NEON
}

Expand All @@ -314,21 +307,20 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
_p = vmulq_f32(_p, _gamma);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
gamma_ptr += 1;
}
}

if (elempack == 1)
{
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
float32x4_t _gamma = vld1q_f32(gamma_ptr);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
_p = vmulq_f32(_p, _gamma);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
Expand All @@ -339,7 +331,7 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr[0]);
ptr[0] = float32_to_bfloat16((v * a) * gamma_ptr[0]);
ptr[0] = float32_to_bfloat16((v * rms) * gamma_ptr[0]);
ptr++;
gamma_ptr++;
}
Expand All @@ -351,15 +343,15 @@ static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _rms);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr[0]);
ptr[0] = float32_to_bfloat16(v * a);
ptr[0] = float32_to_bfloat16(v * rms);
ptr++;
}
}
Expand Down
Loading

0 comments on commit 2064a25

Please sign in to comment.