From 5df5413c81312b0106fe18066b47e2917afabd27 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 2 Sep 2024 18:48:01 +0800 Subject: [PATCH 1/4] embed int8 quantization and add embed test (#5667) --- .ci/pnnx.yml | 2 + docs/developer-guide/operators.md | 2 + src/layer/embed.cpp | 88 +++++++++++++++++++++--- src/layer/embed.h | 6 ++ tests/CMakeLists.txt | 1 + tests/test_embed.cpp | 108 ++++++++++++++++++++++++++++++ tools/modelwriter.h | 11 +++ tools/quantize/ncnn2int8.cpp | 52 ++++++++++++++ 8 files changed, 261 insertions(+), 9 deletions(-) create mode 100644 tests/test_embed.cpp diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 990690e0c5b7..d49da39a0afc 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -4,12 +4,14 @@ on: branches: [master] paths: - '.ci/pnnx.yml' + - 'src/layer/*' - 'tools/pnnx/**' - '!tools/pnnx/README.md' mr: target-branches: [master] paths: - '.ci/pnnx.yml' + - 'src/layer/*' - 'tools/pnnx/**' - '!tools/pnnx/README.md' concurrency: diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 7594c0843acb..de4d6b428e99 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -837,11 +837,13 @@ y = embedding(x) | 1 | input_dim | int | 0 | | | 2 | bias_term | int | 0 | | | 3 | weight_data_size | int | 0 | | +| 18 | int8_scale_term| int | 0 | | | weight | type | shape | | ------------- | ----- | --------------------- | | weight_data | float | [weight_data_size] | | bias_term | float | [num_output] | +| weight_data_int8_scales| float | [1] | # Exp ``` diff --git a/src/layer/embed.cpp b/src/layer/embed.cpp index ddda6b8bf199..2b9f8a60042c 100644 --- a/src/layer/embed.cpp +++ b/src/layer/embed.cpp @@ -30,6 +30,7 @@ int Embed::load_param(const ParamDict& pd) input_dim = pd.get(1, 0); bias_term = pd.get(2, 0); weight_data_size = pd.get(3, 0); + int8_scale_term = pd.get(18, 0); return 0; } @@ -47,18 +48,23 @@ int Embed::load_model(const ModelBin& mb) return -100; } +#if NCNN_INT8 + if (int8_scale_term) + { + weight_data_int8_scale = mb.load(1, 1)[0]; + } +#endif // NCNN_INT8 + return 0; } -int Embed::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +static void embed(const Mat& bottom_blob, const Mat& weight_data, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt) { - int words = static_cast(bottom_blob.total()); + const int num_output = top_blob.w; + const int words = top_blob.h; - top_blob.create(num_output, words, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; + const float* bias_ptr = bias_data; - // num_output #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < words; q++) { @@ -73,15 +79,79 @@ int Embed::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) con const float* em = (const float*)weight_data + num_output * word_index; - memcpy(outptr, em, num_output * sizeof(float)); + if (bias_ptr) + { + for (int p = 0; p < num_output; p++) + { + outptr[p] = em[p] + bias_ptr[p]; + } + } + else + { + memcpy(outptr, em, num_output * sizeof(float)); + } + } +} + +#if NCNN_INT8 +static void embed_int8(const Mat& bottom_blob, const Mat& weight_data, float weight_data_int8_scale, const Mat& bias_data, Mat& top_blob, int input_dim, const Option& opt) +{ + const int num_output = top_blob.w; + const int words = top_blob.h; + + const float* bias_ptr = bias_data; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < words; q++) + { + float* outptr = top_blob.row(q); + + int word_index = ((const int*)bottom_blob)[q]; - if (bias_term) + if (word_index < 0) + word_index = 0; + if (word_index >= input_dim) + word_index = input_dim - 1; + + const float descale_em = 1.f / weight_data_int8_scale; + + const signed char* em = (const signed char*)weight_data + num_output * word_index; + + if (bias_ptr) { for (int p = 0; p < num_output; p++) { - outptr[p] += bias_data[p]; + outptr[p] = em[p] * descale_em + bias_ptr[p]; } } + else + { + for (int p = 0; p < num_output; p++) + { + outptr[p] = em[p] * descale_em; + } + } + } +} +#endif // NCNN_INT8 + +int Embed::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int words = static_cast(bottom_blob.total()); + + top_blob.create(num_output, words, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + +#if NCNN_INT8 + if (int8_scale_term) + { + embed_int8(bottom_blob, weight_data, weight_data_int8_scale, bias_data, top_blob, input_dim, opt); + } + else +#endif // NCNN_INT8 + { + embed(bottom_blob, weight_data, bias_data, top_blob, input_dim, opt); } return 0; diff --git a/src/layer/embed.h b/src/layer/embed.h index 8e2366567163..b94c2b17bee4 100644 --- a/src/layer/embed.h +++ b/src/layer/embed.h @@ -38,9 +38,15 @@ class Embed : public Layer int weight_data_size; + int int8_scale_term; + // model Mat weight_data; Mat bias_data; + +#if NCNN_INT8 + float weight_data_int8_scale; +#endif }; } // namespace ncnn diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6c8939fc7c7e..e2ddc32a00dc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -101,6 +101,7 @@ ncnn_add_layer_test(Dropout) ncnn_add_layer_test(Einsum) ncnn_add_layer_test(Eltwise) ncnn_add_layer_test(ELU) +ncnn_add_layer_test(Embed) ncnn_add_layer_test(Erf) ncnn_add_layer_test(ExpandDims) ncnn_add_layer_test(Flatten) diff --git a/tests/test_embed.cpp b/tests/test_embed.cpp new file mode 100644 index 000000000000..9c007ee5d7e7 --- /dev/null +++ b/tests/test_embed.cpp @@ -0,0 +1,108 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +static int test_embed(int words, int num_output, int input_dim, int bias) +{ + ncnn::ParamDict pd; + pd.set(0, num_output); + pd.set(1, input_dim); + pd.set(2, bias); + pd.set(3, num_output * input_dim); + + std::vector weights(bias ? 2 : 1); + weights[0] = RandomMat(num_output * input_dim); + if (bias) + weights[1] = RandomMat(num_output); + + ncnn::Mat a(words); + RandomizeInt(a, 0, input_dim); + + int ret = test_layer("Embed", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_embed failed words=%d num_output=%d input_dim=%d bias=%d\n", words, num_output, input_dim, bias); + } + + return ret; +} + +static int test_embed_0() +{ + return 0 + || test_embed(128, 128, 128, 0) + || test_embed(128, 128, 128, 1) + || test_embed(127, 127, 127, 0) + || test_embed(127, 127, 127, 1) + || test_embed(124, 124, 124, 0) + || test_embed(124, 124, 124, 1); +} + +#if NCNN_INT8 +static int test_embed_int8(int words, int num_output, int input_dim, int bias) +{ + ncnn::ParamDict pd; + pd.set(0, num_output); + pd.set(1, input_dim); + pd.set(2, bias); + pd.set(3, num_output * input_dim); + pd.set(18, 2); + + std::vector weights(bias ? 3 : 2); + weights[0] = RandomS8Mat(num_output * input_dim); + if (bias) + { + weights[1] = RandomMat(num_output); + weights[2] = RandomMat(1, 100.f, 200.f); + } + else + { + weights[1] = RandomMat(1, 100.f, 200.f); + } + + ncnn::Mat a(words); + RandomizeInt(a, 0, input_dim); + + int ret = test_layer("Embed", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_embed_int8 failed words=%d num_output=%d input_dim=%d bias=%d\n", words, num_output, input_dim, bias); + } + + return ret; +} + +static int test_embed_1() +{ + return 0 + || test_embed_int8(128, 128, 128, 0) + || test_embed_int8(128, 128, 128, 1) + || test_embed_int8(127, 127, 127, 0) + || test_embed_int8(127, 127, 127, 1) + || test_embed_int8(124, 124, 124, 0) + || test_embed_int8(124, 124, 124, 1); +} +#endif // NCNN_INT8 + +int main() +{ + SRAND(7767517); + +#if NCNN_INT8 + return test_embed_0() || test_embed_1(); +#else + return test_embed_0(); +#endif +} diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 4f445cfe2a4d..39157c453ece 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1676,9 +1676,20 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 1=%d", input_dim) fprintf_param_value(" 2=%d", bias_term) fprintf_param_value(" 3=%d", weight_data_size) + fprintf_param_value(" 18=%d", int8_scale_term) fwrite_weight_tag_data(op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); + +#if NCNN_INT8 + // write int8_scale data + if (op->int8_scale_term) + { + ncnn::Mat weight_data_int8_scales(1); + weight_data_int8_scales[0] = op->weight_data_int8_scale; + fwrite_weight_data(weight_data_int8_scales, bp, 90, 100); + } +#endif // NCNN_INT8 } else if (layer->type == "Exp") { diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index 4d19ceb6f166..5e92b333aa57 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -133,6 +133,8 @@ class NetQuantize : public ModelWriter int quantize_lstm(); int quantize_gru(); + int quantize_embed(); + int fuse_requantize(); }; @@ -562,6 +564,55 @@ int NetQuantize::quantize_gru() return 0; } +int NetQuantize::quantize_embed() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "Embed") + continue; + + // Embed - quantize weight from fp32 to int8 + ncnn::Embed* embed = (ncnn::Embed*)layers[i]; + + fprintf(stderr, "quantize_embed %s\n", embed->name.c_str()); + + // TODO move to ncnn2table + + const int num_output = embed->num_output; + const int input_dim = embed->input_dim; + + ncnn::Mat weight_data_int8_scales(1); + { + const float* ptr = embed->weight_data; + float absmax = 0.f; + for (int i = 0; i < embed->weight_data.w; i++) + { + absmax = std::max(absmax, (float)fabs(ptr[i])); + } + + weight_data_int8_scales[0] = absmax == 0.f ? 1.f : 127 / absmax; + } + + { + ncnn::Mat weight_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = embed->weight_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(embed->weight_data, weight_data_int8, weight_data_int8_scales, opt_q); + if (weight_data_int8.empty()) + return -100; + + embed->weight_data = weight_data_int8; + } + + embed->int8_scale_term = 2; + embed->weight_data_int8_scale = weight_data_int8_scales[0]; + } + + return 0; +} + int NetQuantize::fuse_requantize() { const size_t layer_count = layers.size(); @@ -809,6 +860,7 @@ int main(int argc, char** argv) quantizer.quantize_rnn(); quantizer.quantize_lstm(); quantizer.quantize_gru(); + quantizer.quantize_embed(); quantizer.fuse_requantize(); From 8077d340a905ff4b15f7c266da85c811983e6291 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 3 Sep 2024 17:16:50 +0800 Subject: [PATCH 2/4] arm neon optimzation for rmsnorm (#5668) --- src/layer/arm/rmsnorm_arm.cpp | 417 ++++++++++++++++++++++++++ src/layer/arm/rmsnorm_arm.h | 40 +++ src/layer/arm/rmsnorm_arm_asimdhp.cpp | 272 +++++++++++++++++ 3 files changed, 729 insertions(+) create mode 100644 src/layer/arm/rmsnorm_arm.cpp create mode 100644 src/layer/arm/rmsnorm_arm.h create mode 100644 src/layer/arm/rmsnorm_arm_asimdhp.cpp diff --git a/src/layer/arm/rmsnorm_arm.cpp b/src/layer/arm/rmsnorm_arm.cpp new file mode 100644 index 000000000000..e19136ca29d6 --- /dev/null +++ b/src/layer/arm/rmsnorm_arm.cpp @@ -0,0 +1,417 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "rmsnorm_arm.h" + +#if __ARM_NEON +#include +#endif // __ARM_NEON + +#include "arm_usability.h" +#include "cpu.h" + +namespace ncnn { + +RMSNorm_arm::RMSNorm_arm() +{ +#if __ARM_NEON + support_packing = true; +#if NCNN_ARM82 + support_fp16_storage = cpu_support_arm_asimdhp(); +#endif +#endif // __ARM_NEON + +#if NCNN_BF16 + support_bf16_storage = true; +#endif +} + +static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + +#if __ARM_NEON + float32x4_t _rms = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float rms = 0.f; + { + const float* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr0); + _rms = vmlaq_f32(_rms, _p, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + rms += ptr0[0] * ptr0[0]; + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + +#if __aarch64__ + _rms = vdivq_f32(_rms, _elemcount); + _rms = vaddq_f32(_rms, _eps); +#else + float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount); + _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); + _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); + _rms = vmlaq_f32(_eps, _rms, _inv_elemcount); +#endif + + float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms); + _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + } +#endif // __ARM_NEON + if (elempack == 1) + { +#if __ARM_NEON +#if __aarch64__ + rms += vaddvq_f32(_rms); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms)); + _s2 = vpadd_f32(_s2, _s2); + rms += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + rms = 1.f / sqrtf(rms / elemcount + eps); +#if __ARM_NEON + _rms = vdupq_n_f32(rms); +#endif // __ARM_NEON + } + + if (gamma_ptr) + { + int i = 0; +#if __ARM_NEON + if (elempack == 4) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + _p = vmulq_f32(_p, _rms); + _p = vmulq_f32(_p, _gamma); + vst1q_f32(ptr, _p); + ptr += 4; + gamma_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + _p = vmulq_f32(_p, _rms); + _p = vmulq_f32(_p, _gamma); + vst1q_f32(ptr, _p); + ptr += 4; + gamma_ptr += 4; + } + } +#endif // __ARM_NEON + for (; i < size; i++) + { + ptr[0] = (ptr[0] * rms) * gamma_ptr[0]; + ptr++; + gamma_ptr++; + } + } + else + { + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + _p = vmulq_f32(_p, _rms); + vst1q_f32(ptr, _p); + ptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + ptr[0] = ptr[0] * rms; + ptr++; + } + } +} + +int RMSNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int elembits = bottom_top_blob.elembits(); + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + return forward_inplace_fp16s(bottom_top_blob, opt); +#endif + +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); +#endif + + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + float* ptr = bottom_top_blob; + rmsnorm(ptr, gamma_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.row(i); + rmsnorm(ptr, gamma_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.channel(q).row(i); + rmsnorm(ptr, gamma_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + rmsnorm(ptr, gamma_data, eps, w * h, elempack); + } + } + } + + return 0; +} + +#if NCNN_BF16 +static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + +#if __ARM_NEON + float32x4_t _rms = vdupq_n_f32(0.f); +#endif // __ARM_NEON + float rms = 0.f; + { + const unsigned short* ptr0 = ptr; + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr0)); + _rms = vmlaq_f32(_rms, _p, _p); + ptr0 += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr0[0]); + rms += v * v; + ptr0++; + } + } + +#if __ARM_NEON + if (elempack == 4) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + +#if __aarch64__ + _rms = vdivq_f32(_rms, _elemcount); + _rms = vaddq_f32(_rms, _eps); +#else + float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount); + _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); + _inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount); + _rms = vmlaq_f32(_eps, _rms, _inv_elemcount); +#endif + + float32x4_t _rsqrt_rms = vrsqrteq_f32(_rms); + _rsqrt_rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + _rms = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms, _rsqrt_rms), _rsqrt_rms), _rsqrt_rms); + } +#endif // __ARM_NEON + if (elempack == 1) + { +#if __ARM_NEON +#if __aarch64__ + rms += vaddvq_f32(_rms); +#else + float32x2_t _s2 = vadd_f32(vget_low_f32(_rms), vget_high_f32(_rms)); + _s2 = vpadd_f32(_s2, _s2); + rms += vget_lane_f32(_s2, 0); +#endif +#endif // __ARM_NEON + + rms = 1.f / sqrtf(rms / elemcount + eps); +#if __ARM_NEON + _rms = vdupq_n_f32(rms); +#endif // __ARM_NEON + } + + if (gamma_ptr) + { + int i = 0; +#if __ARM_NEON + if (elempack == 4) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + _p = vmulq_f32(_p, _rms); + _p = vmulq_f32(_p, _gamma); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + gamma_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + _p = vmulq_f32(_p, _rms); + _p = vmulq_f32(_p, _gamma); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + gamma_ptr += 4; + } + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr[0]); + ptr[0] = float32_to_bfloat16((v * rms) * gamma_ptr[0]); + ptr++; + gamma_ptr++; + } + } + else + { + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + _p = vmulq_f32(_p, _rms); + vst1_u16(ptr, float2bfloat(_p)); + ptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + float v = bfloat16_to_float32(ptr[0]); + ptr[0] = float32_to_bfloat16(v * rms); + ptr++; + } + } +} + +int RMSNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + unsigned short* ptr = bottom_top_blob; + rmsnorm_bf16s(ptr, gamma_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.channel(q).row(i); + rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + rmsnorm_bf16s(ptr, gamma_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/arm/rmsnorm_arm.h b/src/layer/arm/rmsnorm_arm.h new file mode 100644 index 000000000000..440153333710 --- /dev/null +++ b/src/layer/arm/rmsnorm_arm.h @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_RMSNORM_ARM_H +#define LAYER_RMSNORM_ARM_H + +#include "rmsnorm.h" + +namespace ncnn { + +class RMSNorm_arm : public RMSNorm +{ +public: + RMSNorm_arm(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if NCNN_ARM82 + int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const; +#endif +#if NCNN_BF16 + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; +#endif +}; + +} // namespace ncnn + +#endif // LAYER_RMSNORM_ARM_H diff --git a/src/layer/arm/rmsnorm_arm_asimdhp.cpp b/src/layer/arm/rmsnorm_arm_asimdhp.cpp new file mode 100644 index 000000000000..98d8e6964876 --- /dev/null +++ b/src/layer/arm/rmsnorm_arm_asimdhp.cpp @@ -0,0 +1,272 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "rmsnorm_arm.h" + +#if __ARM_NEON +#include +#include "arm_usability.h" +#endif // __ARM_NEON + +namespace ncnn { + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +static void rmsnorm_fp16s(__fp16* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + + float32x4_t _rms0 = vdupq_n_f32(0.f); + float32x4_t _rms1 = vdupq_n_f32(0.f); + float rms = 0.f; + { + const __fp16* ptr0 = ptr; + + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr0); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + _rms0 = vmlaq_f32(_rms0, _p0, _p0); + _rms1 = vmlaq_f32(_rms1, _p1, _p1); + ptr0 += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr0)); + _rms0 = vmlaq_f32(_rms0, _p, _p); + ptr0 += 4; + } + for (; i < size; i++) + { + rms += (float)ptr0[0] * (float)ptr0[0]; + ptr0++; + } + } + + if (elempack == 8) + { + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + + _rms0 = vdivq_f32(_rms0, _elemcount); + _rms1 = vdivq_f32(_rms1, _elemcount); + _rms0 = vaddq_f32(_rms0, _eps); + _rms1 = vaddq_f32(_rms1, _eps); + + float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0); + float32x4_t _rsqrt_rms1 = vrsqrteq_f32(_rms1); + _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rsqrt_rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1); + _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms1, _rsqrt_rms1), _rsqrt_rms1), _rsqrt_rms1); + } + if (elempack == 4) + { + _rms0 = vaddq_f32(_rms0, _rms1); + + float32x4_t _elemcount = vdupq_n_f32(elemcount); + float32x4_t _eps = vdupq_n_f32(eps); + + _rms0 = vdivq_f32(_rms0, _elemcount); + _rms0 = vaddq_f32(_rms0, _eps); + + float32x4_t _rsqrt_rms0 = vrsqrteq_f32(_rms0); + _rsqrt_rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_rms0, _rsqrt_rms0), _rsqrt_rms0), _rsqrt_rms0); + _rms1 = _rms0; + } + if (elempack == 1) + { + _rms0 = vaddq_f32(_rms0, _rms1); + rms += vaddvq_f32(_rms0); + + rms = 1.f / sqrtf(rms / elemcount + eps); + _rms0 = vdupq_n_f32(rms); + _rms1 = _rms0; + } + + if (gamma_ptr) + { + int i = 0; + if (elempack == 8) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); + _p0 = vmulq_f32(_p0, _gamma); + _p1 = vmulq_f32(_p1, _gamma); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 1; + } + } + if (elempack == 4) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma0 = vdupq_n_f32(gamma_ptr[0]); + float32x4_t _gamma1 = vdupq_n_f32(gamma_ptr[1]); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); + _p0 = vmulq_f32(_p0, _gamma0); + _p1 = vmulq_f32(_p1, _gamma1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 2; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]); + _p = vmulq_f32(_p, _rms0); + _p = vmulq_f32(_p, _gamma); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + gamma_ptr += 1; + } + } + if (elempack == 1) + { + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + float32x4_t _gamma0 = vld1q_f32(gamma_ptr); + float32x4_t _gamma1 = vld1q_f32(gamma_ptr + 4); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); + _p0 = vmulq_f32(_p0, _gamma0); + _p1 = vmulq_f32(_p1, _gamma1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + gamma_ptr += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + float32x4_t _gamma = vld1q_f32(gamma_ptr); + _p = vmulq_f32(_p, _rms0); + _p = vmulq_f32(_p, _gamma); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + gamma_ptr += 4; + } + } + for (; i < size; i++) + { + ptr[0] = (__fp16)(((float)ptr[0] * rms) * gamma_ptr[0]); + ptr++; + gamma_ptr++; + } + } + else + { + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float32x4_t _p0 = vcvt_f32_f16(vget_low_f16(_p)); + float32x4_t _p1 = vcvt_f32_f16(vget_high_f16(_p)); + _p0 = vmulq_f32(_p0, _rms0); + _p1 = vmulq_f32(_p1, _rms1); + _p = vcombine_f16(vcvt_f16_f32(_p0), vcvt_f16_f32(_p1)); + vst1q_f16(ptr, _p); + ptr += 8; + } + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vmulq_f32(_p, _rms0); + vst1_f16(ptr, vcvt_f16_f32(_p)); + ptr += 4; + } + for (; i < size; i++) + { + ptr[0] = (__fp16)((float)ptr[0] * rms); + ptr++; + } + } +} + +int RMSNorm_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + __fp16* ptr = bottom_top_blob; + rmsnorm_fp16s(ptr, gamma_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + rmsnorm_fp16s(ptr, gamma_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.channel(q).row<__fp16>(i); + rmsnorm_fp16s(ptr, gamma_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + __fp16* ptr = bottom_top_blob.channel(q); + rmsnorm_fp16s(ptr, gamma_data, eps, w * h, elempack); + } + } + } + + return 0; +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +} // namespace ncnn From 204583ba52cbc1e4b39b4e77ee1b050eeb1734b7 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 3 Sep 2024 17:17:03 +0800 Subject: [PATCH 3/4] x86 sse2/avx/avx512 optimization for rmsnorm (#5672) --- src/layer/x86/rmsnorm_x86.cpp | 413 ++++++++++++++++++++++++++++++++++ src/layer/x86/rmsnorm_x86.h | 32 +++ 2 files changed, 445 insertions(+) create mode 100644 src/layer/x86/rmsnorm_x86.cpp create mode 100644 src/layer/x86/rmsnorm_x86.h diff --git a/src/layer/x86/rmsnorm_x86.cpp b/src/layer/x86/rmsnorm_x86.cpp new file mode 100644 index 000000000000..db592c3e3810 --- /dev/null +++ b/src/layer/x86/rmsnorm_x86.cpp @@ -0,0 +1,413 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "rmsnorm_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +#include "x86_usability.h" + +namespace ncnn { + +RMSNorm_x86::RMSNorm_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + +static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack) +{ + const int size = elemcount * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _rms_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _rms_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _rms = _mm_set1_ps(0.f); +#endif // __SSE2__ + float rms = 0.f; + { + const float* ptr0 = ptr; + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr0); + _rms_avx512 = _mm512_fmadd_ps(_p, _p, _rms_avx512); + ptr0 += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr0); + _rms_avx = _mm256_comp_fmadd_ps(_p, _p, _rms_avx); + ptr0 += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr0); + _rms = _mm_comp_fmadd_ps(_p, _p, _rms); + ptr0 += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + rms += ptr0[0] * ptr0[0]; + ptr0++; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _elemcount = _mm512_set1_ps((float)elemcount); + __m512 _eps = _mm512_set1_ps(eps); + + _rms_avx512 = _mm512_div_ps(_rms_avx512, _elemcount); + _rms_avx512 = _mm512_add_ps(_rms_avx512, _eps); + + __m256 _rms0 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 0)); + __m256 _rms1 = _mm256_rsqrt_ps(_mm512_extractf32x8_ps(_rms_avx512, 1)); + _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms0), _rms1, 1); + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512); + __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1)); + _rms_avx = _mm256_add_ps(_rms_avx, _rms0); + _rms_avx = _mm256_add_ps(_rms_avx, _rms1); + } +#endif // __AVX512F__ + + __m256 _elemcount = _mm256_set1_ps((float)elemcount); + __m256 _eps = _mm256_set1_ps(eps); + + _rms_avx = _mm256_div_ps(_rms_avx, _elemcount); + _rms_avx = _mm256_add_ps(_rms_avx, _eps); + + _rms_avx = _mm256_rsqrt_ps(_rms_avx); +#if __AVX512F__ + _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1); +#endif // __AVX512F__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _rms0 = _mm512_castps512_ps256(_rms_avx512); + __m256 _rms1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_rms_avx512), 1)); + _rms_avx = _mm256_add_ps(_rms_avx, _rms0); + _rms_avx = _mm256_add_ps(_rms_avx, _rms1); + } +#endif // __AVX512F__ + { + __m128 _rms0 = _mm256_castps256_ps128(_rms_avx); + __m128 _rms1 = _mm256_extractf128_ps(_rms_avx, 1); + _rms = _mm_add_ps(_rms, _rms0); + _rms = _mm_add_ps(_rms, _rms1); + } +#endif // __AVX__ + + __m128 _elemcount = _mm_set1_ps((float)elemcount); + __m128 _eps = _mm_set1_ps(eps); + + _rms = _mm_div_ps(_rms, _elemcount); + _rms = _mm_add_ps(_rms, _eps); + + _rms = _mm_rsqrt_ps(_rms); +#if __AVX__ + _rms_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_rms), _rms, 1); +#if __AVX512F__ + _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1); +#endif // __AVX512F__ +#endif // __AVX__ + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + rms += _mm512_comp_reduce_add_ps(_rms_avx512); +#endif // __AVX512F__ + rms += _mm256_reduce_add_ps(_rms_avx); +#endif // __AVX__ + rms += _mm_reduce_add_ps(_rms); +#endif // __SSE2__ + + rms = 1.f / sqrtf(rms / elemcount + eps); +#if __SSE2__ + _rms = _mm_set1_ps(rms); +#if __AVX__ + _rms_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_rms), _rms, 1); +#if __AVX512F__ + _rms_avx512 = _mm512_insertf32x8(_mm512_castps256_ps512(_rms_avx), _rms_avx, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + } + + if (gamma_ptr) + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_set1_ps(gamma_ptr[0]); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 1; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m256 _gamma0 = _mm256_set1_ps(gamma_ptr[0]); + __m256 _gamma1 = _mm256_set1_ps(gamma_ptr[1]); + __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma0), _gamma1, 1); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 2; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_set1_ps(gamma_ptr[0]); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 1; + } + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m128 _gamma2 = _mm_set1_ps(gamma_ptr[2]); + __m128 _gamma3 = _mm_set1_ps(gamma_ptr[3]); + __m256 _gamma01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + __m256 _gamma23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma2), _gamma3, 1); + __m512 _gamma = _mm512_insertf32x8(_mm512_castps256_ps512(_gamma01), _gamma23, 1); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 4; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m128 _gamma0 = _mm_set1_ps(gamma_ptr[0]); + __m128 _gamma1 = _mm_set1_ps(gamma_ptr[1]); + __m256 _gamma = _mm256_insertf128_ps(_mm256_castps128_ps256(_gamma0), _gamma1, 1); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 2; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_set1_ps(gamma_ptr[0]); + _p = _mm_mul_ps(_p, _rms); + _p = _mm_mul_ps(_p, _gamma); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 1; + } + } + if (elempack == 1) + { +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _gamma = _mm512_loadu_ps(gamma_ptr); + _p = _mm512_mul_ps(_p, _rms_avx512); + _p = _mm512_mul_ps(_p, _gamma); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + gamma_ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _gamma = _mm256_loadu_ps(gamma_ptr); + _p = _mm256_mul_ps(_p, _rms_avx); + _p = _mm256_mul_ps(_p, _gamma); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + gamma_ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _gamma = _mm_loadu_ps(gamma_ptr); + _p = _mm_mul_ps(_p, _rms); + _p = _mm_mul_ps(_p, _gamma); + _mm_storeu_ps(ptr, _p); + ptr += 4; + gamma_ptr += 4; + } + } +#endif // __SSE2__ + for (; i < size; i++) + { + ptr[0] = (ptr[0] * rms) * gamma_ptr[0]; + ptr++; + gamma_ptr++; + } + } + else + { + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _p = _mm512_mul_ps(_p, _rms_avx512); + _mm512_storeu_ps(ptr, _p); + ptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + _p = _mm256_mul_ps(_p, _rms_avx); + _mm256_storeu_ps(ptr, _p); + ptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + _p = _mm_mul_ps(_p, _rms); + _mm_storeu_ps(ptr, _p); + ptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + ptr[0] = ptr[0] * rms; + ptr++; + } + } +} + +int RMSNorm_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + const int dims = bottom_top_blob.dims; + const int w = bottom_top_blob.w; + const int h = bottom_top_blob.h; + const int channels = bottom_top_blob.c; + const int elempack = bottom_top_blob.elempack; + + if (dims == 1) + { + // assert affine_size == w + + float* ptr = bottom_top_blob; + rmsnorm(ptr, gamma_data, eps, w * elempack, 1); + } + + if (dims == 2) + { + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.row(i); + rmsnorm(ptr, gamma_data, eps, w, elempack); + } + } + + if (dims == 3) + { + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.channel(q).row(i); + rmsnorm(ptr, gamma_data, eps, w, elempack); + } + } + } + else // if (affine_size == w * h) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + rmsnorm(ptr, gamma_data, eps, w * h, elempack); + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/rmsnorm_x86.h b/src/layer/x86/rmsnorm_x86.h new file mode 100644 index 000000000000..2e6296db1c32 --- /dev/null +++ b/src/layer/x86/rmsnorm_x86.h @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_RMSNORM_X86_H +#define LAYER_RMSNORM_X86_H + +#include "rmsnorm.h" + +namespace ncnn { + +class RMSNorm_x86 : public RMSNorm +{ +public: + RMSNorm_x86(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_RMSNORM_X86_H From 21e54d8c7a789884d1c17dc1b40701bede343975 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 4 Sep 2024 08:01:53 +0800 Subject: [PATCH 4/4] update modelwriter for rmsnorm (#5676) --- tools/modelwriter.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 39157c453ece..ff86338bca9c 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -99,6 +99,7 @@ #include "layer/reorg.h" #include "layer/requantize.h" #include "layer/reshape.h" +#include "layer/rmsnorm.h" #include "layer/rnn.h" #include "layer/roialign.h" #include "layer/roipooling.h" @@ -2313,6 +2314,17 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 2=%d", c) fprintf_param_value(" 3=%d", permute) } + else if (layer->type == "RMSNorm") + { + ncnn::RMSNorm* op = (ncnn::RMSNorm*)layer; + ncnn::RMSNorm* op_default = (ncnn::RMSNorm*)layer_default; + + fprintf_param_value(" 0=%d", affine_size) + fprintf_param_value(" 1=%e", eps) + fprintf_param_value(" 2=%d", affine) + + fwrite_weight_data(op->gamma_data, bp); + } else if (layer->type == "RNN") { ncnn::RNN* op = (ncnn::RNN*)layer;