diff --git a/src/layer/arm/rmsnorm_arm.cpp b/src/layer/arm/rmsnorm_arm.cpp
new file mode 100644
index 000000000000..e19136ca29d6
--- /dev/null
+++ b/src/layer/arm/rmsnorm_arm.cpp
@@ -0,0 +1,417 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "rmsnorm_arm.h"
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#endif // __ARM_NEON
+
+#include "arm_usability.h"
+#include "cpu.h"
+
+namespace ncnn {
+
+RMSNorm_arm::RMSNorm_arm()
+{
+#if __ARM_NEON
+    support_packing = true;
+#if NCNN_ARM82
+    support_fp16_storage = cpu_support_arm_asimdhp();
+#endif
+#endif // __ARM_NEON
+
+#if NCNN_BF16
+    support_bf16_storage = true;
+#endif
+}
+
+static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+#if __ARM_NEON
+    float32x4_t _rms = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float rms = 0.f;
+    {
+        const float* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vld1q_f32(ptr0);
+            _rms = vmlaq_f32(_rms, _p, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            rms += ptr0[0] * ptr0[0];
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+
+#if __aarch64__
+        _rms = vdivq_f32(_rms, _elemcount);
+        _rms = vaddq_f32(_rms, _eps);
+#else
+        float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
+        _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
+        _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
+        _rms = vmlaq_f32(_eps, _rms, _inv_elemcount);
+#endif
+
+        float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms);
+        _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
+        _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        rms += vaddvq_f32(_rms);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms));
+        _s2 = vpadd_f32(_s2, _s2);
+        rms += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        rms = 1.f / sqrtf(rms / elemcount + eps);
+#if __ARM_NEON
+        _rms = vdupq_n_f32(rms);
+#endif // __ARM_NEON
+    }
+
+    if (gamma_ptr)
+    {
+        int i = 0;
+#if __ARM_NEON
+        if (elempack == 4)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vld1q_f32(ptr);
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                _p = vmulq_f32(_p, _rms);
+                _p = vmulq_f32(_p, _gamma);
+                vst1q_f32(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vld1q_f32(ptr);
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                _p = vmulq_f32(_p, _rms);
+                _p = vmulq_f32(_p, _gamma);
+                vst1q_f32(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 4;
+            }
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            ptr[0] = (ptr[0] * rms) * gamma_ptr[0];
+            ptr++;
+            gamma_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vld1q_f32(ptr);
+            _p = vmulq_f32(_p, _rms);
+            vst1q_f32(ptr, _p);
+            ptr += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            ptr[0] = ptr[0] * rms;
+            ptr++;
+        }
+    }
+}
+
+int RMSNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+    int elembits = bottom_top_blob.elembits();
+
+#if NCNN_ARM82
+    if (support_fp16_storage && opt.use_fp16_storage && elembits == 16)
+        return forward_inplace_fp16s(bottom_top_blob, opt);
+#endif
+
+#if NCNN_BF16
+    if (opt.use_bf16_storage && elembits == 16)
+        return forward_inplace_bf16s(bottom_top_blob, opt);
+#endif
+
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        float* ptr = bottom_top_blob;
+        rmsnorm(ptr, gamma_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            float* ptr = bottom_top_blob.row(i);
+            rmsnorm(ptr, gamma_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    float* ptr = bottom_top_blob.channel(q).row(i);
+                    rmsnorm(ptr, gamma_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                float* ptr = bottom_top_blob.channel(q);
+                rmsnorm(ptr, gamma_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+
+#if NCNN_BF16
+static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+#if __ARM_NEON
+    float32x4_t _rms = vdupq_n_f32(0.f);
+#endif // __ARM_NEON
+    float rms = 0.f;
+    {
+        const unsigned short* ptr0 = ptr;
+
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = bfloat2float(vld1_u16(ptr0));
+            _rms = vmlaq_f32(_rms, _p, _p);
+            ptr0 += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr0[0]);
+            rms += v * v;
+            ptr0++;
+        }
+    }
+
+#if __ARM_NEON
+    if (elempack == 4)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+
+#if __aarch64__
+        _rms = vdivq_f32(_rms, _elemcount);
+        _rms = vaddq_f32(_rms, _eps);
+#else
+        float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
+        _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
+        _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
+        _rms = vmlaq_f32(_eps, _rms, _inv_elemcount);
+#endif
+
+        float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms);
+        _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
+        _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms);
+    }
+#endif // __ARM_NEON
+    if (elempack == 1)
+    {
+#if __ARM_NEON
+#if __aarch64__
+        rms += vaddvq_f32(_rms);
+#else
+        float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms));
+        _s2 = vpadd_f32(_s2, _s2);
+        rms += vget_lane_f32(_s2, 0);
+#endif
+#endif // __ARM_NEON
+
+        rms = 1.f / sqrtf(rms / elemcount + eps);
+#if __ARM_NEON
+        _rms = vdupq_n_f32(rms);
+#endif // __ARM_NEON
+    }
+
+    if (gamma_ptr)
+    {
+        int i = 0;
+#if __ARM_NEON
+        if (elempack == 4)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = bfloat2float(vld1_u16(ptr));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                _p = vmulq_f32(_p, _rms);
+                _p = vmulq_f32(_p, _gamma);
+                vst1_u16(ptr, float2bfloat(_p));
+                ptr += 4;
+                gamma_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = bfloat2float(vld1_u16(ptr));
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                _p = vmulq_f32(_p, _rms);
+                _p = vmulq_f32(_p, _gamma);
+                vst1_u16(ptr, float2bfloat(_p));
+                ptr += 4;
+                gamma_ptr += 4;
+            }
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr[0]);
+            ptr[0] = float32_to_bfloat16((v * rms) * gamma_ptr[0]);
+            ptr++;
+            gamma_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+#if __ARM_NEON
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = bfloat2float(vld1_u16(ptr));
+            _p = vmulq_f32(_p, _rms);
+            vst1_u16(ptr, float2bfloat(_p));
+            ptr += 4;
+        }
+#endif // __ARM_NEON
+        for (; i < size; i++)
+        {
+            float v = bfloat16_to_float32(ptr[0]);
+            ptr[0] = float32_to_bfloat16(v * rms);
+            ptr++;
+        }
+    }
+}
+
+int RMSNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
+{
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        unsigned short* ptr = bottom_top_blob;
+        rmsnorm_bf16s(ptr, gamma_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            unsigned short* ptr = bottom_top_blob.row<unsigned short>(i);
+            rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    unsigned short* ptr = bottom_top_blob.channel(q).row<unsigned short>(i);
+                    rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                unsigned short* ptr = bottom_top_blob.channel(q);
+                rmsnorm_bf16s(ptr, gamma_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+#endif // NCNN_BF16
+
+} // namespace ncnn
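Editor's note on the bf16 path above: ncnn stores bfloat16 as unsigned short, and the bfloat2float/float2bfloat helpers come from arm_usability.h. For reference, bf16 is just the upper half of an IEEE float32 bit pattern; a minimal scalar sketch of the conversion pair (a truncating form, for illustration only; ncnn's actual helpers may round rather than truncate):

    #include <cstdint>
    #include <cstring>

    // float32 -> bfloat16: keep the upper 16 bits of the bit pattern
    static unsigned short f32_to_bf16_trunc(float x)
    {
        uint32_t u;
        memcpy(&u, &x, 4);
        return (unsigned short)(u >> 16);
    }

    // bfloat16 -> float32: place the 16 bits back in the upper half
    static float bf16_to_f32(unsigned short v)
    {
        uint32_t u = (uint32_t)v << 16;
        float x;
        memcpy(&x, &u, 4);
        return x;
    }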
diff --git a/src/layer/arm/rmsnorm_arm.h b/src/layer/arm/rmsnorm_arm.h
new file mode 100644
index 000000000000..440153333710
--- /dev/null
+++ b/src/layer/arm/rmsnorm_arm.h
@@ -0,0 +1,40 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_RMSNORM_ARM_H
+#define LAYER_RMSNORM_ARM_H
+
+#include "rmsnorm.h"
+
+namespace ncnn {
+
+class RMSNorm_arm : public RMSNorm
+{
+public:
+    RMSNorm_arm();
+
+    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+
+protected:
+#if NCNN_ARM82
+    int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const;
+#endif
+#if NCNN_BF16
+    int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const;
+#endif
+};
+
+} // namespace ncnn
+
+#endif // LAYER_RMSNORM_ARM_H
diff --git a/src/layer/arm/rmsnorm_arm_asimdhp.cpp b/src/layer/arm/rmsnorm_arm_asimdhp.cpp
new file mode 100644
index 000000000000..98d8e6964876
--- /dev/null
+++ b/src/layer/arm/rmsnorm_arm_asimdhp.cpp
@@ -0,0 +1,272 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "rmsnorm_arm.h"
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#include "arm_usability.h"
+#endif // __ARM_NEON
+
+namespace ncnn {
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+    float32x4_t _rms0 = vdupq_n_f32(0.f);
+    float32x4_t _rms1 = vdupq_n_f32(0.f);
+    float rms = 0.f;
+    {
+        const __fp16* ptr0 = ptr;
+
+        int i = 0;
+        for (; i + 7 < size; i += 8)
+        {
+            float16x8_t _p = vld1q_f16(ptr0);
+            float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+            float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+            _rms0 = vmlaq_f32(_rms0, _p0, _p0);
+            _rms1 = vmlaq_f32(_rms1, _p1, _p1);
+            ptr0 += 8;
+        }
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0));
+            _rms0 = vmlaq_f32(_rms0, _p, _p);
+            ptr0 += 4;
+        }
+        for (; i < size; i++)
+        {
+            rms += (float)ptr0[0] * (float)ptr0[0];
+            ptr0++;
+        }
+    }
+
+    if (elempack == 8)
+    {
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+
+        _rms0 = vdivq_f32(_rms0, _elemcount);
+        _rms1 = vdivq_f32(_rms1, _elemcount);
+        _rms0 = vaddq_f32(_rms0, _eps);
+        _rms1 = vaddq_f32(_rms1, _eps);
+
+        float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0);
+        float32x4_t _rsqrt_rms1 = vrsqrteq_f32(_rms1);
+        _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0);
+        _rsqrt_rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1);
+        _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0);
+        _rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1);
+    }
+    if (elempack == 4)
+    {
+        _rms0 = vaddq_f32(_rms0, _rms1);
+
+        float32x4_t _elemcount = vdupq_n_f32(elemcount);
+        float32x4_t _eps = vdupq_n_f32(eps);
+
+        _rms0 = vdivq_f32(_rms0, _elemcount);
+        _rms0 = vaddq_f32(_rms0, _eps);
+
+        float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0);
+        _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0);
+        _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0);
+        _rms1 = _rms0;
+    }
+    if (elempack == 1)
+    {
+        _rms0 = vaddq_f32(_rms0, _rms1);
+        rms += vaddvq_f32(_rms0);
+
+        rms = 1.f / sqrtf(rms / elemcount + eps);
+        _rms0 = vdupq_n_f32(rms);
+        _rms1 = _rms0;
+    }
+
+    if (gamma_ptr)
+    {
+        int i = 0;
+        if (elempack == 8)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                _p0 = vmulq_f32(_p0, _rms0);
+                _p1 = vmulq_f32(_p1, _rms1);
+                _p0 = vmulq_f32(_p0, _gamma);
+                _p1 = vmulq_f32(_p1, _gamma);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 1;
+            }
+        }
+        if (elempack == 4)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma0 = vdupq_n_f32(gamma_ptr[0]);
+                float32x4_t _gamma1 = vdupq_n_f32(gamma_ptr[1]);
+                _p0 = vmulq_f32(_p0, _rms0);
+                _p1 = vmulq_f32(_p1, _rms1);
+                _p0 = vmulq_f32(_p0, _gamma0);
+                _p1 = vmulq_f32(_p1, _gamma1);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 2;
+            }
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+                float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
+                _p = vmulq_f32(_p, _rms0);
+                _p = vmulq_f32(_p, _gamma);
+                vst1_f16(ptr, vcvt_f16_f32(_p));
+                ptr += 4;
+                gamma_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+            for (; i + 7 < size; i += 8)
+            {
+                float16x8_t _p = vld1q_f16(ptr);
+                float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+                float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+                float32x4_t _gamma0 = vld1q_f32(gamma_ptr);
+                float32x4_t _gamma1 = vld1q_f32(gamma_ptr + 4);
+                _p0 = vmulq_f32(_p0, _rms0);
+                _p1 = vmulq_f32(_p1, _rms1);
+                _p0 = vmulq_f32(_p0, _gamma0);
+                _p1 = vmulq_f32(_p1, _gamma1);
+                _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+                vst1q_f16(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 8;
+            }
+            for (; i + 3 < size; i += 4)
+            {
+                float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+                float32x4_t _gamma = vld1q_f32(gamma_ptr);
+                _p = vmulq_f32(_p, _rms0);
+                _p = vmulq_f32(_p, _gamma);
+                vst1_f16(ptr, vcvt_f16_f32(_p));
+                ptr += 4;
+                gamma_ptr += 4;
+            }
+        }
+        for (; i < size; i++)
+        {
+            ptr[0] = (__fp16)(((float)ptr[0] * rms) * gamma_ptr[0]);
+            ptr++;
+            gamma_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+        for (; i + 7 < size; i += 8)
+        {
+            float16x8_t _p = vld1q_f16(ptr);
+            float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p));
+            float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p));
+            _p0 = vmulq_f32(_p0, _rms0);
+            _p1 = vmulq_f32(_p1, _rms1);
+            _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1));
+            vst1q_f16(ptr, _p);
+            ptr += 8;
+        }
+        for (; i + 3 < size; i += 4)
+        {
+            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
+            _p = vmulq_f32(_p, _rms0);
+            vst1_f16(ptr, vcvt_f16_f32(_p));
+            ptr += 4;
+        }
+        for (; i < size; i++)
+        {
+            ptr[0] = (__fp16)((float)ptr[0] * rms);
+            ptr++;
+        }
+    }
+}
+
+int RMSNorm_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
+{
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        __fp16* ptr = bottom_top_blob;
+        rmsnorm_fp16s(ptr, gamma_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            __fp16* ptr = bottom_top_blob.row<__fp16>(i);
+            rmsnorm_fp16s(ptr, gamma_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    __fp16* ptr = bottom_top_blob.channel(q).row<__fp16>(i);
+                    rmsnorm_fp16s(ptr, gamma_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                __fp16* ptr = bottom_top_blob.channel(q);
+                rmsnorm_fp16s(ptr, gamma_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+} // namespace ncnn
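Editor's note: the NEON paths above never call a square root directly. vrsqrteq_f32 yields a coarse reciprocal-square-root estimate, and each vrsqrtsq_f32 multiply is one Newton-Raphson step, since vrsqrtsq_f32(a, b) computes (3 - a*b) / 2. A scalar sketch of the same two-step refinement, for readers tracing the intrinsics:

    // y approximates 1/sqrt(x); each step roughly doubles the bits of accuracy
    static float rsqrt_refine(float x, float y /* initial estimate, cf. vrsqrteq_f32 */)
    {
        y = y * (3.f - x * y * y) * 0.5f; // first vrsqrtsq_f32 step
        y = y * (3.f - x * y * y) * 0.5f; // second step, close to fp32 accuracy
        return y;
    }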
diff --git a/src/layer/x86/rmsnorm_x86.cpp b/src/layer/x86/rmsnorm_x86.cpp
new file mode 100644
index 000000000000..db592c3e3810
--- /dev/null
+++ b/src/layer/x86/rmsnorm_x86.cpp
@@ -0,0 +1,413 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "rmsnorm_x86.h"
+
+#if __SSE2__
+#include <emmintrin.h>
+#if __AVX__
+#include <immintrin.h>
+#endif // __AVX__
+#endif // __SSE2__
+
+#include "x86_usability.h"
+
+namespace ncnn {
+
+RMSNorm_x86::RMSNorm_x86()
+{
+#if __SSE2__
+    support_packing = true;
+#endif // __SSE2__
+}
+
+static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack)
+{
+    const int size = elemcount * elempack;
+
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+    __m512 _rms_avx512 = _mm512_set1_ps(0.f);
+#endif // __AVX512F__
+    __m256 _rms_avx = _mm256_set1_ps(0.f);
+#endif // __AVX__
+    __m128 _rms = _mm_set1_ps(0.f);
+#endif // __SSE2__
+    float rms = 0.f;
+    {
+        const float* ptr0 = ptr;
+
+        int i = 0;
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        for (; i + 15 < size; i += 16)
+        {
+            __m512 _p = _mm512_loadu_ps(ptr0);
+            _rms_avx512 = _mm512_fmadd_ps(_p, _p, _rms_avx512);
+            ptr0 += 16;
+        }
+#endif // __AVX512F__
+        for (; i + 7 < size; i += 8)
+        {
+            __m256 _p = _mm256_loadu_ps(ptr0);
+            _rms_avx = _mm256_comp_fmadd_ps(_p, _p, _rms_avx);
+            ptr0 += 8;
+        }
+#endif // __AVX__
+        for (; i + 3 < size; i += 4)
+        {
+            __m128 _p = _mm_loadu_ps(ptr0);
+            _rms = _mm_comp_fmadd_ps(_p, _p, _rms);
+            ptr0 += 4;
+        }
+#endif // __SSE2__
+        for (; i < size; i++)
+        {
+            rms += ptr0[0] * ptr0[0];
+            ptr0++;
+        }
+    }
+
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+    if (elempack == 16)
+    {
+        __m512 _elemcount = _mm512_set1_ps((float)elemcount);
+        __m512 _eps = _mm512_set1_ps(eps);
+
+        _rms_avx512 = _mm512_div_ps(_rms_avx512, _elemcount);
+        _rms_avx512 = _mm512_add_ps(_rms_avx512, _eps);
+
+        __m256 _rms0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 0));
+        __m256 _rms1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 1));
+        _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms0), _rms1, 1);
+    }
+#endif // __AVX512F__
+    if (elempack == 8)
+    {
+#if __AVX512F__
+        {
+            __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512);
+            __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1));
+            _rms_avx = _mm256_add_ps(_rms_avx, _rms0);
+            _rms_avx = _mm256_add_ps(_rms_avx, _rms1);
+        }
+#endif // __AVX512F__
+
+        __m256 _elemcount = _mm256_set1_ps((float)elemcount);
+        __m256 _eps = _mm256_set1_ps(eps);
+
+        _rms_avx = _mm256_div_ps(_rms_avx, _elemcount);
+        _rms_avx = _mm256_add_ps(_rms_avx, _eps);
+
+        _rms_avx = _mm256_rsqrt_ps(_rms_avx);
+#if __AVX512F__
+        _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1);
+#endif // __AVX512F__
+    }
+#endif // __AVX__
+    if (elempack == 4)
+    {
+#if __AVX__
+#if __AVX512F__
+        {
+            __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512);
+            __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1));
+            _rms_avx = _mm256_add_ps(_rms_avx, _rms0);
+            _rms_avx = _mm256_add_ps(_rms_avx, _rms1);
+        }
+#endif // __AVX512F__
+        {
+            __m128 _rms0 = _mm256_castps256_ps128(_rms_avx);
+            __m128 _rms1 = _mm256_extractf128_ps(_rms_avx, 1);
+            _rms = _mm_add_ps(_rms, _rms0);
+            _rms = _mm_add_ps(_rms, _rms1);
+        }
+#endif // __AVX__
+
+        __m128 _elemcount = _mm_set1_ps((float)elemcount);
+        __m128 _eps = _mm_set1_ps(eps);
+
+        _rms = _mm_div_ps(_rms, _elemcount);
+        _rms = _mm_add_ps(_rms, _eps);
+
+        _rms = _mm_rsqrt_ps(_rms);
+#if __AVX__
+        _rms_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_rms), _rms, 1);
+#if __AVX512F__
+        _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1);
+#endif // __AVX512F__
+#endif // __AVX__
+    }
+#endif // __SSE2__
+    if (elempack == 1)
+    {
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        rms += _mm512_comp_reduce_add_ps(_rms_avx512);
+#endif // __AVX512F__
+        rms += _mm256_reduce_add_ps(_rms_avx);
+#endif // __AVX__
+        rms += _mm_reduce_add_ps(_rms);
+#endif // __SSE2__
+
+        rms = 1.f / sqrtf(rms / elemcount + eps);
+#if __SSE2__
+        _rms = _mm_set1_ps(rms);
+#if __AVX__
+        _rms_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_rms), _rms, 1);
+#if __AVX512F__
+        _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1);
+#endif // __AVX512F__
+#endif // __AVX__
+#endif // __SSE2__
+    }
+
+    if (gamma_ptr)
+    {
+        int i = 0;
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        if (elempack == 16)
+        {
+            for (; i + 15 < size; i += 16)
+            {
+                __m512 _p = _mm512_loadu_ps(ptr);
+                __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]);
+                _p = _mm512_mul_ps(_p, _rms_avx512);
+                _p = _mm512_mul_ps(_p, _gamma);
+                _mm512_storeu_ps(ptr, _p);
+                ptr += 16;
+                gamma_ptr += 1;
+            }
+        }
+#endif // __AVX512F__
+        if (elempack == 8)
+        {
+#if __AVX512F__
+            for (; i + 15 < size; i += 16)
+            {
+                __m512 _p = _mm512_loadu_ps(ptr);
+                __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]);
+                __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]);
+                __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma0), _gamma1, 1);
+                _p = _mm512_mul_ps(_p, _rms_avx512);
+                _p = _mm512_mul_ps(_p, _gamma);
+                _mm512_storeu_ps(ptr, _p);
+                ptr += 16;
+                gamma_ptr += 2;
+            }
+#endif // __AVX512F__
+            for (; i + 7 < size; i += 8)
+            {
+                __m256 _p = _mm256_loadu_ps(ptr);
+                __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]);
+                _p = _mm256_mul_ps(_p, _rms_avx);
+                _p = _mm256_mul_ps(_p, _gamma);
+                _mm256_storeu_ps(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 1;
+            }
+        }
+#endif // __AVX__
+        if (elempack == 4)
+        {
+#if __AVX__
+#if __AVX512F__
+            for (; i + 15 < size; i += 16)
+            {
+                __m512 _p = _mm512_loadu_ps(ptr);
+                __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]);
+                __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]);
+                __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]);
+                __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]);
+                __m256 _gamma01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1);
+                __m256 _gamma23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma2), _gamma3, 1);
+                __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma01), _gamma23, 1);
+                _p = _mm512_mul_ps(_p, _rms_avx512);
+                _p = _mm512_mul_ps(_p, _gamma);
+                _mm512_storeu_ps(ptr, _p);
+                ptr += 16;
+                gamma_ptr += 4;
+            }
+#endif // __AVX512F__
+            for (; i + 7 < size; i += 8)
+            {
+                __m256 _p = _mm256_loadu_ps(ptr);
+                __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]);
+                __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]);
+                __m256 _gamma = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1);
+                _p = _mm256_mul_ps(_p, _rms_avx);
+                _p = _mm256_mul_ps(_p, _gamma);
+                _mm256_storeu_ps(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 2;
+            }
+#endif // __AVX__
+            for (; i + 3 < size; i += 4)
+            {
+                __m128 _p = _mm_loadu_ps(ptr);
+                __m128 _gamma = _mm_set1_ps(gamma_ptr[0]);
+                _p = _mm_mul_ps(_p, _rms);
+                _p = _mm_mul_ps(_p, _gamma);
+                _mm_storeu_ps(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 1;
+            }
+        }
+        if (elempack == 1)
+        {
+#if __AVX__
+#if __AVX512F__
+            for (; i + 15 < size; i += 16)
+            {
+                __m512 _p = _mm512_loadu_ps(ptr);
+                __m512 _gamma = _mm512_loadu_ps(gamma_ptr);
+                _p = _mm512_mul_ps(_p, _rms_avx512);
+                _p = _mm512_mul_ps(_p, _gamma);
+                _mm512_storeu_ps(ptr, _p);
+                ptr += 16;
+                gamma_ptr += 16;
+            }
+#endif // __AVX512F__
+            for (; i + 7 < size; i += 8)
+            {
+                __m256 _p = _mm256_loadu_ps(ptr);
+                __m256 _gamma = _mm256_loadu_ps(gamma_ptr);
+                _p = _mm256_mul_ps(_p, _rms_avx);
+                _p = _mm256_mul_ps(_p, _gamma);
+                _mm256_storeu_ps(ptr, _p);
+                ptr += 8;
+                gamma_ptr += 8;
+            }
+#endif // __AVX__
+            for (; i + 3 < size; i += 4)
+            {
+                __m128 _p = _mm_loadu_ps(ptr);
+                __m128 _gamma = _mm_loadu_ps(gamma_ptr);
+                _p = _mm_mul_ps(_p, _rms);
+                _p = _mm_mul_ps(_p, _gamma);
+                _mm_storeu_ps(ptr, _p);
+                ptr += 4;
+                gamma_ptr += 4;
+            }
+        }
+#endif // __SSE2__
+        for (; i < size; i++)
+        {
+            ptr[0] = (ptr[0] * rms) * gamma_ptr[0];
+            ptr++;
+            gamma_ptr++;
+        }
+    }
+    else
+    {
+        int i = 0;
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+        for (; i + 15 < size; i += 16)
+        {
+            __m512 _p = _mm512_loadu_ps(ptr);
+            _p = _mm512_mul_ps(_p, _rms_avx512);
+            _mm512_storeu_ps(ptr, _p);
+            ptr += 16;
+        }
+#endif // __AVX512F__
+        for (; i + 7 < size; i += 8)
+        {
+            __m256 _p = _mm256_loadu_ps(ptr);
+            _p = _mm256_mul_ps(_p, _rms_avx);
+            _mm256_storeu_ps(ptr, _p);
+            ptr += 8;
+        }
+#endif // __AVX__
+        for (; i + 3 < size; i += 4)
+        {
+            __m128 _p = _mm_loadu_ps(ptr);
+            _p = _mm_mul_ps(_p, _rms);
+            _mm_storeu_ps(ptr, _p);
+            ptr += 4;
+        }
+#endif // __SSE2__
+        for (; i < size; i++)
+        {
+            ptr[0] = ptr[0] * rms;
+            ptr++;
+        }
+    }
+}
+
+int RMSNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
+{
+    const int dims = bottom_top_blob.dims;
+    const int w = bottom_top_blob.w;
+    const int h = bottom_top_blob.h;
+    const int channels = bottom_top_blob.c;
+    const int elempack = bottom_top_blob.elempack;
+
+    if (dims == 1)
+    {
+        // assert affine_size == w
+
+        float* ptr = bottom_top_blob;
+        rmsnorm(ptr, gamma_data, eps, w * elempack, 1);
+    }
+
+    if (dims == 2)
+    {
+        // assert affine_size == w
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < h; i++)
+        {
+            float* ptr = bottom_top_blob.row(i);
+            rmsnorm(ptr, gamma_data, eps, w, elempack);
+        }
+    }
+
+    if (dims == 3)
+    {
+        if (affine_size == w)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    float* ptr = bottom_top_blob.channel(q).row(i);
+                    rmsnorm(ptr, gamma_data, eps, w, elempack);
+                }
+            }
+        }
+        else // if (affine_size == w * h)
+        {
+            #pragma omp parallel for num_threads(opt.num_threads)
+            for (int q = 0; q < channels; q++)
+            {
+                float* ptr = bottom_top_blob.channel(q);
+                rmsnorm(ptr, gamma_data, eps, w * h, elempack);
+            }
+        }
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
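Editor's note: unlike the NEON code, the x86 packed paths use the raw _mm_rsqrt_ps / _mm256_rsqrt_ps estimates (about 12 bits of relative accuracy) with no refinement step, which plausibly motivates the relaxed 1e-3 test tolerances later in this patch. If tighter accuracy were ever wanted, a hypothetical one-step Newton-Raphson refinement (not part of this PR) could look like:

    #include <xmmintrin.h>

    // refine y ~= 1/sqrt(x): y' = y * (1.5 - 0.5 * x * y * y)
    static __m128 rsqrt_nr_ps(__m128 x)
    {
        __m128 y = _mm_rsqrt_ps(x);
        __m128 half_x = _mm_mul_ps(x, _mm_set1_ps(0.5f));
        __m128 y2 = _mm_mul_ps(y, y);
        return _mm_mul_ps(y, _mm_sub_ps(_mm_set1_ps(1.5f), _mm_mul_ps(half_x, y2)));
    }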
diff --git a/src/layer/x86/rmsnorm_x86.h b/src/layer/x86/rmsnorm_x86.h
new file mode 100644
index 000000000000..2e6296db1c32
--- /dev/null
+++ b/src/layer/x86/rmsnorm_x86.h
@@ -0,0 +1,32 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_RMSNORM_X86_H
+#define LAYER_RMSNORM_X86_H
+
+#include "rmsnorm.h"
+
+namespace ncnn {
+
+class RMSNorm_x86 : public RMSNorm
+{
+public:
+    RMSNorm_x86();
+
+    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_RMSNORM_X86_H
diff --git a/tools/modelwriter.h b/tools/modelwriter.h
index 39157c453ece..ff86338bca9c 100644
--- a/tools/modelwriter.h
+++ b/tools/modelwriter.h
@@ -99,6 +99,7 @@
 #include "layer/reorg.h"
 #include "layer/requantize.h"
 #include "layer/reshape.h"
+#include "layer/rmsnorm.h"
 #include "layer/rnn.h"
 #include "layer/roialign.h"
 #include "layer/roipooling.h"
@@ -2313,6 +2314,17 @@ int ModelWriter::save(const char* parampath, const char* binpath)
             fprintf_param_value(" 2=%d", c)
             fprintf_param_value(" 3=%d", permute)
         }
+        else if (layer->type == "RMSNorm")
+        {
+            ncnn::RMSNorm* op = (ncnn::RMSNorm*)layer;
+            ncnn::RMSNorm* op_default = (ncnn::RMSNorm*)layer_default;
+
+            fprintf_param_value(" 0=%d", affine_size)
+            fprintf_param_value(" 1=%e", eps)
+            fprintf_param_value(" 2=%d", affine)
+
+            fwrite_weight_data(op->gamma_data, bp);
+        }
         else if (layer->type == "RNN")
         {
             ncnn::RNN* op = (ncnn::RNN*)layer;
diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 2c814bd486cd..7743a8ae453e 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -369,6 +369,7 @@ set(pnnx_pass_level5_SRCS
     pass_level5/fuse_pixel_unshuffle.cpp
     pass_level5/fuse_layernorm.cpp
     pass_level5/fuse_multiheadattention.cpp
+    pass_level5/fuse_rmsnorm.cpp
     pass_level5/fuse_scaled_dot_product_attention.cpp
     pass_level5/fuse_select_to_unbind.cpp
     pass_level5/fuse_silu.cpp
diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp
index 8bb3270aa2c3..5f08b80f5ef9 100644
--- a/tools/pnnx/src/pass_level5.cpp
+++ b/tools/pnnx/src/pass_level5.cpp
@@ -44,6 +44,7 @@
 #include "pass_level5/fuse_multiheadattention.h"
 #include "pass_level5/fuse_pad_conv1d.h"
 #include "pass_level5/fuse_pad_conv2d.h"
+#include "pass_level5/fuse_rmsnorm.h"
 #include "pass_level5/fuse_scaled_dot_product_attention.h"
 #include "pass_level5/fuse_select_to_unbind.h"
 #include "pass_level5/fuse_silu.h"
@@ -145,6 +146,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const char* device)
     fuse_channel_shuffle(g);
     fuse_layernorm(g);
+    fuse_rmsnorm(g);
     fuse_multiheadattention(g);
     fuse_scaled_dot_product_attention(g);
diff --git a/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp b/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp
new file mode 100644
index 000000000000..7b99770ed6ed
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp
@@ -0,0 +1,97 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "fuse_rmsnorm.h"
+
+#include "pass_level2.h"
+
+#include <math.h>
+#include <string.h>
+
+namespace pnnx {
+
+class fuse_rmsnorm_pass : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 5
+pnnx.Input input 0 1 input
+pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
+pnnx.Expression op_1 1 1 input sq expr=pow(@0,2)
+torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
+pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,rsqrt(add(@2,%eps))))
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+nn.RMSNorm rmsnorm 1 1 input out elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+class fuse_rmsnorm_pass_1 : public fuse_rmsnorm_pass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 5
+pnnx.Input input 0 1 input
+pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
+pnnx.Expression op_1 1 1 input sq expr=pow(@0,2.000000e+00)
+torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
+pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,reciprocal(sqrt(add(@2,%eps)))))
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 5
+pnnx.Input input 0 1 input
+pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
+pnnx.Expression op_1 1 1 input sq expr=pow(@0,2.000000e+00)
+torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
+pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,div(1.000000e+00,sqrt(add(@2,%eps)))))
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+void fuse_rmsnorm(Graph& graph)
+{
+    fuse_rmsnorm_pass a;
+    fuse_rmsnorm_pass_1 a1;
+    fuse_rmsnorm_pass_onnx b;
+    int opindex = 0;
+
+    pnnx_graph_rewrite(graph, &a, opindex);
+    pnnx_graph_rewrite(graph, &a1, opindex);
+    pnnx_graph_rewrite(graph, &b, opindex);
+}
+
+} // namespace pnnx
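Editor's note: the three match patterns above are algebraic spellings of a single computation,

    y = w \odot \frac{x}{\sqrt{\operatorname{mean}(x^2) + \varepsilon}},
    \qquad
    \frac{1}{\sqrt{a}} = \operatorname{rsqrt}(a) = \operatorname{reciprocal}(\sqrt{a}) = \operatorname{div}(1, \sqrt{a}),

so the variants differ only in how the traced graph spells the reciprocal square root: TorchScript traces commonly emit rsqrt or reciprocal(sqrt(.)), while the _onnx variant presumably covers exported graphs that spell it as div(1, sqrt(.)).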
+
+#include "ir.h"
+
+namespace pnnx {
+
+void fuse_rmsnorm(Graph& graph);
+
+} // namespace pnnx
diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt
index daf5501e9d8b..0dd566c37b58 100644
--- a/tools/pnnx/tests/CMakeLists.txt
+++ b/tools/pnnx/tests/CMakeLists.txt
@@ -346,6 +346,7 @@ pnnx_add_test(pnnx_fuse_input_unpack)
 pnnx_add_test(pnnx_fuse_layernorm)
 pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
 pnnx_add_test(pnnx_fuse_multiheadattention)
+pnnx_add_test(pnnx_fuse_rmsnorm)
 pnnx_add_test(pnnx_fuse_scaled_dot_product_attention)
 pnnx_add_test(pnnx_fuse_select_to_unbind)
 pnnx_add_test(pnnx_fuse_slice_to_tensor_split)
diff --git a/tools/pnnx/tests/ncnn/test_F_rms_norm.py b/tools/pnnx/tests/ncnn/test_F_rms_norm.py
index 4e60d9314aae..f30f72f9ac45 100644
--- a/tools/pnnx/tests/ncnn/test_F_rms_norm.py
+++ b/tools/pnnx/tests/ncnn/test_F_rms_norm.py
@@ -57,7 +57,7 @@ def test():
     b = test_F_rms_norm_ncnn.test_inference()
 
     for a0, b0 in zip(a, b):
-        if not torch.allclose(a0, b0, 1e-4, 1e-4):
+        if not torch.allclose(a0, b0, 1e-3, 1e-3):
             return False
     return True
diff --git a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py
index 0d5efa211e4d..e69ad1220bc1 100644
--- a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py
+++ b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py
@@ -57,7 +57,7 @@ def test():
     b = test_nn_RMSNorm_ncnn.test_inference()
 
     for a0, b0 in zip(a, b):
-        if not torch.allclose(a0, b0, 1e-4, 1e-4):
+        if not torch.allclose(a0, b0, 1e-3, 1e-3):
             return False
     return True
diff --git a/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py b/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py
new file mode 100644
index 000000000000..b04fa93442fa
--- /dev/null
+++ b/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py
@@ -0,0 +1,77 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+class T5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.rmsn_0 = T5LayerNorm(26)
+        self.rmsn_1 = T5LayerNorm(21)
+
+    def forward(self, x, y):
+        x = self.rmsn_0(x)
+        y = self.rmsn_1(y)
+        return x, y
+
+def test():
+    if version.parse(torch.__version__) < version.parse('2.4'):
+        return True
+
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(1, 64, 26)
+    y = torch.rand(3, 15, 15, 21)
+
+    a0, a1 = net(x, y)
+
+    # export onnx
+    torch.onnx.export(net, (x, y), "test.onnx")
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y))
+    mod.save("test_pnnx_fuse_rmsnorm.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_pnnx_fuse_rmsnorm.pt inputshape=[1,64,26],[3,15,15,21]")
+
+    # pnnx inference
+    import test_pnnx_fuse_rmsnorm_pnnx
+    b0, b1 = test_pnnx_fuse_rmsnorm_pnnx.test_inference()
+
+    return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4)
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
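Editor's note: every SIMD path in this patch must reproduce the scalar computation that the T5LayerNorm module in this test performs, with no mean subtraction and eps added inside the square root. A minimal C++ reference, handy when checking any of the vectorized branches by hand (names are illustrative, not ncnn API):

    #include <cmath>

    // in-place RMSNorm over n values; gamma may be null (no affine scale)
    static void rmsnorm_ref(float* x, const float* gamma, float eps, int n)
    {
        float sum_sq = 0.f;
        for (int i = 0; i < n; i++)
            sum_sq += x[i] * x[i];
        const float scale = 1.f / sqrtf(sum_sq / n + eps); // 1 / RMS
        for (int i = 0; i < n; i++)
            x[i] = gamma ? x[i] * scale * gamma[i] : x[i] * scale;
    }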