From 9c966762c68b1fac079861e8e7135bf32595a1a7 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Mon, 23 May 2022 16:24:23 +0800 Subject: [PATCH 01/11] add v9 build support --- .pre-commit-config.yaml | 2 +- CMakeLists.txt | 1 + cmake/configure.cmake | 4 + cmake/os/android.cmake | 16 +- cmake/postproject.cmake | 11 + lite/backends/arm/math/CMakeLists.txt | 10 +- lite/backends/arm/math/sve/funcs_sve.h | 152 ++++++++++ lite/backends/arm/math/sve/softmax_sve.cc | 328 ++++++++++++++++++++++ lite/backends/arm/math/sve/softmax_sve.h | 46 +++ lite/core/context.h | 1 + lite/core/device_info.cc | 63 +++++ lite/core/device_info.h | 5 + lite/kernels/arm/softmax_compute.cc | 30 ++ lite/tools/build.sh | 11 + 14 files changed, 677 insertions(+), 3 deletions(-) create mode 100644 lite/backends/arm/math/sve/funcs_sve.h create mode 100644 lite/backends/arm/math/sve/softmax_sve.cc create mode 100644 lite/backends/arm/math/sve/softmax_sve.h diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5ebedf10fa3..cca3e6daa6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: hooks: - id: copyright_checker name: copyright_checker - entry: python ./tools/codestyle/copyright.hook + entry: python3 ./tools/codestyle/copyright.hook language: system files: \.(c|cc|cxx|cpp|cu|cl|h|hpp|hxx|proto|py|mm|m|metal)$ exclude: (?!.*third_party)^.*$|(?!.*book)^.*$ diff --git a/CMakeLists.txt b/CMakeLists.txt index 594a06a6692..c11413f2616 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,6 +86,7 @@ lite_option(CUDA_WITH_FP16 "Compile with cuda half support" lite_option(LITE_WITH_ARM_CLANG "when arm lang is clang, its ON." OFF) lite_option(LITE_WITH_XCODE "when debug in xcode, its ON." OFF) lite_option(LITE_WITH_ARM82_FP16 "when compile with arm v8.2 fp16, it's ON." OFF) +lite_option(LITE_WITH_ARM8_SVE2 "Enable SVE2 instructions in ARMv8." OFF) lite_option(LITE_WITH_ARM82_INT8_SDOT "when compile with arm v8.2 int8, it's ON." OFF) lite_option(LITE_WITH_CODE_META_INFO "include git version in the header file." ON) lite_option(WITH_NODE_RAW_FS "(Only available when compiling by Emscripten) Whether build with NODERAWFS" OFF) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index e57155902c6..f2f85bcb6df 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -297,6 +297,10 @@ if (LITE_WITH_ARM82_FP16) add_definitions("-DLITE_WITH_ARM82_FP16") endif(LITE_WITH_ARM82_FP16) +if (LITE_WITH_ARM8_SVE2) + add_definitions("-DLITE_WITH_ARM82_FP16") +endif(LITE_WITH_ARM8_SVE2) + if (LITE_WITH_M1) add_definitions("-DLITE_WITH_M1") endif(LITE_WITH_M1) diff --git a/cmake/os/android.cmake b/cmake/os/android.cmake index f813bc02d5a..ff988ea32f8 100644 --- a/cmake/os/android.cmake +++ b/cmake/os/android.cmake @@ -97,7 +97,21 @@ if(ARM_TARGET_LANG STREQUAL "clang") set(triple aarch64-v8a-linux-android) if(ANDROID_STL_TYPE MATCHES "^c\\+\\+_") # Use CMAKE_CXX_STANDARD_LIBRARIES_INIT to ensure libunwind and libc++ is linked in the right order - set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a") + # set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a") + set(LIBUNWIND_PATH "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libunwind.a") + if(EXISTS ${LIBUNWIND_PATH}) + message(STATUS "libunwind is in ${LIBUNWIND_PATH}") + else() + # happened when NDK >= 23 + file(GLOB_RECURSE WIND_PATH "${CMAKE_ANDROID_NDK}/*/libunwind.a") + foreach(loop_path ${WIND_PATH}) + string(FIND ${loop_path} "aarch64" STR_END) + string(SUBSTRING ${loop_path} 0 ${STR_END} REAL_LIBUNWIND_PATH) + break() + endforeach() + set(LIBUNWIND_PATH "${REAL_LIBUNWIND_PATH}aarch64/libunwind.a") + endif() + set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${LIBUNWIND_PATH}") if (ANDROID_NATIVE_API_LEVEL LESS 21) set(CMAKE_CXX_STANDARD_LIBRARIES_INIT "${CMAKE_CXX_STANDARD_LIBRARIES_INIT} ${ANDROID_NDK}/sources/cxx-stl/llvm-libc++/libs/${ANDROID_ARCH_ABI}/libandroid_support.a") endif() diff --git a/cmake/postproject.cmake b/cmake/postproject.cmake index fad675a9c55..d098b2c893e 100644 --- a/cmake/postproject.cmake +++ b/cmake/postproject.cmake @@ -41,6 +41,17 @@ if(ANDROID) endif() endif() + if(LITE_WITH_ARM8_SVE2) + if(${ANDROID_NDK_MAJOR}) + if(${ANDROID_NDK_MAJOR} GREATER_EQUAL "23") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+sve2") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+sve2") + else() + message(FATAL_ERROR "NDK VERSION: ${ANDROID_NDK_MAJOR}, however it must be greater equal 23 when sve2 is ON") + endif() + endif() + endif() + if(LITE_WITH_ARM82_INT8_SDOT) if(${ANDROID_NDK_MAJOR}) if(${ANDROID_NDK_MAJOR} GREATER "17") diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index a79890c8201..e1fb40438a5 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -29,8 +29,16 @@ endif () FILE(GLOB ARM_MATH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) # fp16 arm math source code in fp16/ directory FILE(GLOB FP16_ARM_MATH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) - +# sve2 arm math source code in sve2/ directory +FILE(GLOB SVE2_ARM_MATH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sve2/*.cc) +FILE(GLOB SVE_ARM_MATH_SRC ${CMAKE_CURRENT_SOURCE_DIR}/sve/*.cc) if(LITE_WITH_ARM82_FP16) set(ARM_MATH_SRC ${ARM_MATH_SRC} ${FP16_ARM_MATH_SRC}) endif() + +if(LITE_WITH_ARM8_SVE2) + set(ARM_MATH_SRC ${ARM_MATH_SRC} ${SVE_ARM_MATH_SRC}) + set(ARM_MATH_SRC ${ARM_MATH_SRC} ${SVE2_ARM_MATH_SRC}) +endif() + lite_cc_library(math_arm SRCS ${ARM_MATH_SRC}) diff --git a/lite/backends/arm/math/sve/funcs_sve.h b/lite/backends/arm/math/sve/funcs_sve.h new file mode 100644 index 00000000000..e16cc69899f --- /dev/null +++ b/lite/backends/arm/math/sve/funcs_sve.h @@ -0,0 +1,152 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/sve/softmax_sve.h" + +#ifdef ENABLE_ARM_FP16 +#include "lite/backends/arm/math/fp16/funcs_fp16.h" +typedef __fp16 float16_t; +#endif +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +namespace sve { + +template +inline svbool_t svptrue_size(); + +template <> +inline svbool_t svptrue_size<64>() { + return svptrue_b64(); +} + +template <> +inline svbool_t svptrue_size<32>() { + return svptrue_b32(); +} + +template <> +inline svbool_t svptrue_size<16>() { + return svptrue_b16(); +} + +template <> +inline svbool_t svptrue_size<8>() { + return svptrue_b8(); +} + +template +svbool_t svptrue() { + return svptrue_size(); +} + +template +inline uint64_t svcnt_size(); + +template <> +inline uint64_t svcnt_size<64>() { + return svcntd(); +} + +template <> +inline uint64_t svcnt_size<32>() { + return svcntw(); +} + +template <> +inline uint64_t svcnt_size<16>() { + return svcnth(); +} + +template <> +inline uint64_t svcnt_size<8>() { + return svcntb(); +} + +template +inline uint64_t svcnt() { + return svcnt_size(); +} + +#define SVDUP_N_IMPL(Intype, Vectortype, postfix) \ + inline Vectortype svdup_n(Intype a) { return svdup_n_##postfix(a); } + +SVDUP_N_IMPL(int8_t, svint8_t, s8) +SVDUP_N_IMPL(int16_t, svint16_t, s16) +SVDUP_N_IMPL(int32_t, svint32_t, s32) +SVDUP_N_IMPL(int64_t, svint64_t, s64) +SVDUP_N_IMPL(uint8_t, svuint8_t, u8) +SVDUP_N_IMPL(uint16_t, svuint16_t, u16) +SVDUP_N_IMPL(uint32_t, svuint32_t, u32) +SVDUP_N_IMPL(uint64_t, svuint64_t, u64) +SVDUP_N_IMPL(float16_t, svfloat16_t, f16) +SVDUP_N_IMPL(float, svfloat32_t, f32) +SVDUP_N_IMPL(bfloat16_t, svbfloat16_t, bf16) + +#undef SVDUP_N_IMPL + +#define SVWHILELT_IMPL(type) \ + template \ + inline svbool_t svwhilelt_size(type a, type b); \ + template <> \ + inline svbool_t svwhilelt_size<64>(type a, type b) { \ + return svwhilelt_b64(a, b); \ + } \ + template <> \ + inline svbool_t svwhilelt_size<32>(type a, type b) { \ + return svwhilelt_b32(a, b); \ + } \ + template <> \ + inline svbool_t svwhilelt_size<16>(type a, type b) { \ + return svwhilelt_b16(a, b); \ + } \ + template <> \ + inline svbool_t svwhilelt_size<8>(type a, type b) { \ + return svwhilelt_b8(a, b); \ + } + +SVWHILELT_IMPL(int32_t) +SVWHILELT_IMPL(int64_t) + +#undef SVWHILELT_IMPL + +template +inline svbool_t svwhilelt(IndexType a, IndexType b) { + return svwhilelt_size(a, b); +} + +#define SVEXP_IMPL(vtype, postfix) \ + inline vtype svexp_z(svbool_t pg, const vtype &a) { \ + return svexp_##postfix##_z(pg, a); \ + } + +SVEXP_IMPL(svfloat32_t, f32) +SVEXP_IMPL(svfloat16_t, f16) + +#undef SVEXP_IMPL +} // namespace sve +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/sve/softmax_sve.cc b/lite/backends/arm/math/sve/softmax_sve.cc new file mode 100644 index 00000000000..8314b0391c6 --- /dev/null +++ b/lite/backends/arm/math/sve/softmax_sve.cc @@ -0,0 +1,328 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lite/backends/arm/math/sve/softmax_sve.h" +#include +#include "lite/backends/arm/math/sve/funcs_sve.h" +#include "lite/core/parallel_defines.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +namespace sve { + +template +void softmax_basic_sve(const Dtype* din, + Dtype* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + + LITE_PARALLEL_BEGIN(i, tid, compute_size) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + Dtype max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + Dtype sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + Dtype sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } + LITE_PARALLEL_END() +} + +template +void softmax_axis4_sve(const Dtype* din, + Dtype* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + auto vone = svdup_n(static_cast(1)); + auto vinf = svdiv_z(vone, vsum); + int i = 0; + LITE_PARALLEL_COMMON_BEGIN(i, tid, compute_size, 0, svcnt()) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + svbool_t pg = svwhilelt(i, compute_size); + const Dtype* din_ptr0 = din + real_index; + const Dtype* din_ptr1 = din_ptr0 + inner_num; + const Dtype* din_ptr2 = din_ptr1 + inner_num; + const Dtype* din_ptr3 = din_ptr2 + inner_num; + auto vdata0 = svld1(pg, din_ptr0); + auto vdata1 = svld1(pg, din_ptr1); + auto vdata2 = svld1(pg, din_ptr2); + auto vdata3 = svld1(pg, din_ptr3); + Dtype* dout_ptr0 = dout + real_index; + Dtype* dout_ptr1 = dout_ptr0 + inner_num; + // get max + auto vmax0 = svmax_m(pg, vdata0, vdata1); + auto vmax1 = svmax_m(pg, vdata2, vdata3); + Dtype* dout_ptr2 = dout_ptr1 + inner_num; + Dtype* dout_ptr3 = dout_ptr2 + inner_num; + auto vmax = svmax_m(pg, vmax0, vmax1); + // sub, exp and sum + auto vsum0 = svexp_z(pg, svsub_z(pg, vdata0, vmax)); + auto vsum1 = svexp_z(pg, svsub_z(pg, vdata1, vmax)); + auto vsum2 = svexp_z(pg, svsub_z(pg, vdata2, vmax)); + auto vsum3 = svexp_z(pg, svsub_z(pg, vdata3, vmax)); + + auto vsum_0 = svadd_m(pg, vsum0, vsum1); + auto vsum_1 = svadd_m(pg, vsum2, vsum3); + auto vsum = svadd_m(pg, vsum_0, vsum_1); + auto vout0 = svmul_z(pg, vsum0, vinf); + auto vout1 = svmul_z(pg, vsum1, vinf); + auto vout2 = svmul_z(pg, vsum2, vinf); + auto vout3 = svmul_z(pg, vsum3, vinf); + svst1(pg, dout_ptr0, vout0); + svst1(pg, dout_ptr1, vout1); + svst1(pg, dout_ptr2, vout2); + svst1(pg, dout_ptr3, vout3); + } + LITE_PARALLEL_END() +} + +template +void softmax_inner1_sve(const Dtype* din, + Dtype* dout, + const int outer_size, + const int axis_size) { + int out_cnt = (outer_size >> 2) << 2; + auto vone = svdup_n(static_cast(1)); + const auto all_true_pg = svptrue(); + int i = 0; + LITE_PARALLEL_COMMON_BEGIN(i, tid, outer_size - 3, 0, 4) { + auto index = i * axis_size; + auto pg = svwhilelt(i, outer_size); + const Dtype* din_ptr0 = din + index; + const Dtype* din_ptr1 = din_ptr0 + axis_size; + const Dtype* din_ptr2 = din_ptr1 + axis_size; + const Dtype* din_ptr3 = din_ptr2 + axis_size; + Dtype* din_max_ptr0 = din_ptr0; + Dtype* din_max_ptr1 = din_ptr1; + Dtype* din_max_ptr2 = din_ptr2; + Dtype* din_max_ptr3 = din_ptr3; + int x = 0; + auto pg0 = svwhilelt(x, axis_size); + auto vec_max0 = svdup_n(support::cpp11::lowest()); + auto vec_max1 = svdup_n(support::cpp11::lowest()); + auto vec_max2 = svdup_n(support::cpp11::lowest()); + auto vec_max3 = svdup_n(support::cpp11::lowest()); + do { + auto vdata0 = svld1(pg, din_max_ptr0); + auto vdata1 = svld1(pg, din_max_ptr1); + auto vdata2 = svld1(pg, din_max_ptr2); + auto vdata3 = svld1(pg, din_max_ptr3); + // get max + auto vmax0 = svmax_m(pg, vec_max0, vdata0); + auto vmax1 = svmax_m(pg, vec_max1, vdata1); + auto vmax0 = svmax_m(pg, vec_max2, vdata2); + auto vmax1 = svmax_m(pg, vec_max3, vdata3); + din_max_ptr0 += svcnt(); + din_max_ptr1 += svcnt(); + din_max_ptr2 += svcnt(); + din_max_ptr3 += svcnt(); + x += svcnt(); + pg = svwhilelt(x, axis_size); + } while (svptest_any(all_true_pg, pg0)); + Dtype vmax_0 = svmaxv(vec_max0); + Dtype vmax_1 = svmaxv(vec_max1); + Dtype vmax_2 = svmaxv(vec_max2); + Dtype vmax_3 = svmaxv(vec_max3); + // sub, exp and sum + x = 0; + din_max_ptr0 = din_ptr0; + din_max_ptr1 = din_ptr1; + din_max_ptr2 = din_ptr2; + din_max_ptr3 = din_ptr3; + Dtype* dout_ptr0 = dout + index; + Dtype* dout_ptr1 = dout_ptr0 + axis_size; + Dtype* dout_ptr2 = dout_ptr1 + axis_size; + Dtype* dout_ptr3 = dout_ptr2 + axis_size; + auto vsum0 = svdup_n(static_cast(0)); + auto vsum1 = svdup_n(static_cast(0)); + auto vsum2 = svdup_n(static_cast(0)); + auto vsum3 = svdup_n(static_cast(0)); + auto vmax0 = svdup_n(vmax_0); + auto vmax1 = svdup_n(vmax_1); + auto vmax2 = svdup_n(vmax_2); + auto vmax3 = svdup_n(vmax_3); + for (int j = 0; j < axis_size; j += svcnt()) { + auto pg0 = svwhilelt(j, axis_size); + auto vsub_exp0 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); + auto vsub_exp1 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr1), vmax1)); + auto vsub_exp2 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr2), vmax2)); + auto vsub_exp3 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr3), vmax3)); + vsum0 = svadd_m(pg0, vsum0, vsub_exp0); + vsum1 = svadd_m(pg0, vsum1, vsub_exp1); + vsum2 = svadd_m(pg0, vsum2, vsub_exp2); + vsum3 = svadd_m(pg0, vsum3, vsub_exp3); + din_max_ptr0 += svcnt(); + din_max_ptr1 += svcnt(); + din_max_ptr2 += svcnt(); + din_max_ptr3 += svcnt(); + svst1(pg0, dout_ptr0, vsub_exp0); + svst1(pg0, dout_ptr1, vsub_exp1); + svst1(pg0, dout_ptr2, vsub_exp2); + svst1(pg0, dout_ptr3, vsub_exp3); + dout_ptr0 += svcnt(); + dout_ptr1 += svcnt(); + dout_ptr2 += svcnt(); + dout_ptr3 += svcnt(); + } + auto vsum_0 = svaddv(vsum0); + auto vsum_1 = svaddv(vsum1); + auto vsum_2 = svaddv(vsum2); + auto vsum_3 = svaddv(vsum3); + auto vinf0 = svmul_z(pg, svdup_n(vsum_0), vinf); + auto vinf1 = svmul_z(pg, svdup_n(vsum_1), vinf); + auto vinf2 = svmul_z(pg, svdup_n(vsum_2), vinf); + auto vinf3 = svmul_z(pg, svdup_n(vsum_3), vinf); + dout_ptr0 = dout + index; + dout_ptr1 = dout_ptr0 + axis_size; + dout_ptr2 = dout_ptr1 + axis_size; + dout_ptr3 = dout_ptr2 + axis_size; + for (int j = 0; j < axis_size; j += svcnt()) { + auto pg0 = svwhilelt(j, axis_size); + auto vsub_exp0 = svmul_z(pg0, svld1(pg0, dout_ptr0), vinf0); + auto vsub_exp1 = svmul_z(pg0, svld1(pg0, dout_ptr1), vinf1); + auto vsub_exp2 = svmul_z(pg0, svld1(pg0, dout_ptr2), vinf2); + auto vsub_exp3 = svmul_z(pg0, svld1(pg0, dout_ptr3), vinf3); + svst1(pg0, dout_ptr0, vsub_exp0); + svst1(pg0, dout_ptr1, vsub_exp1); + svst1(pg0, dout_ptr2, vsub_exp2); + svst1(pg0, dout_ptr3, vsub_exp3); + dout_ptr0 += svcnt(); + dout_ptr1 += svcnt(); + dout_ptr2 += svcnt(); + dout_ptr3 += svcnt(); + } + } + LITE_PARALLEL_END() + LITE_PARALLEL_COMMON_BEGIN(i, tid, outer_size, out_cnt, 1) { + auto index = i * axis_size; + auto pg = svwhilelt(i, outer_size); + const Dtype* din_ptr0 = din + index; + Dtype* din_max_ptr0 = din_ptr0; + int x = 0; + auto pg0 = svwhilelt(x, axis_size); + auto vec_max0 = svdup_n(support::cpp11::lowest()); + do { + auto vdata0 = svld1(pg, din_max_ptr0); + // get max + auto vmax0 = svmax_m(pg, vec_max0, vdata0); + din_max_ptr0 += svcnt(); + x += svcnt(); + pg = svwhilelt(x, axis_size); + } while (svptest_any(all_true_pg, pg0)); + Dtype vmax_0 = svmaxv(vec_max0); + // sub, exp and sum + x = 0; + din_max_ptr0 = din_ptr0; + Dtype* dout_ptr0 = dout + index; + auto vsum0 = svdup_n(static_cast(0)); + auto vmax0 = svdup_n(vmax_0); + for (int j = 0; j < axis_size; j += svcnt()) { + auto pg0 = svwhilelt(j, axis_size); + auto vsub_exp0 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); + + vsum0 = svadd_m(pg0, vsum0, vsub_exp0); + din_max_ptr0 += svcnt(); + svst1(pg0, dout_ptr0, vsub_exp0); + dout_ptr0 += svcnt(); + } + auto vsum_0 = svaddv(vsum0); + auto vinf0 = svmul_z(pg, svdup_n(vsum_0), vinf); + dout_ptr0 = dout + index; + for (int j = 0; j < axis_size; j += svcnt()) { + auto pg0 = svwhilelt(j, axis_size); + auto vsub_exp0 = svmul_z(pg0, svld1(pg0, dout_ptr0), vinf0); + svst1(pg0, dout_ptr0, vsub_exp0); + dout_ptr0 += svcnt(); + } + } + LITE_PARALLEL_END() +} + +template <> +softmax_basic_sve(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template <> +softmax_axis4_sve(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template <> +softmax_inner1_sve(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +#ifdef ENABLE_ARM_FP16 +template <> +softmax_basic_sve(const float16_t* din, + float16_t* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template <> +softmax_axis4_sve(const float16_t* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template <> +softmax_inner1_sve(const float16_t* din, + float16_t* dout, + const int axis_size, + const int inner_num, + const int outer_num); +#endif + +} // namespace sve +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/sve/softmax_sve.h b/lite/backends/arm/math/sve/softmax_sve.h new file mode 100644 index 00000000000..98329749ad2 --- /dev/null +++ b/lite/backends/arm/math/sve/softmax_sve.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +namespace sve { +template +void softmax_basic_sve(const Dtype* din, + Dtype* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template +void softmax_axis4_sve(const Dtype* din, + Dtype* dout, + const int axis_size, + const int inner_num, + const int outer_num); + +template +void softmax_inner1_sve(const Dtype* din, + Dtype* dout, + const int outer_size, + const int axis_size); + +} // namespace sve +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/core/context.h b/lite/core/context.h index cff75e317a7..d915fe8cb3c 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -370,6 +370,7 @@ class Context { bool has_dot() const { return DeviceInfo::Global().has_dot(); } bool has_fp16() const { return DeviceInfo::Global().has_fp16(); } bool has_a53_valid() const { return DeviceInfo::Global().set_a53_valid(); } + bool has_sve2() const { return DeviceInfo::Global().has_sve2(); } template T* workspace_data() { diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc index 398d8080c21..f925246f741 100644 --- a/lite/core/device_info.cc +++ b/lite/core/device_info.cc @@ -80,6 +80,33 @@ namespace paddle { namespace lite { +// http://elixir.free-electrons.com/linux/latest/source/arch/arm64/include/uapi/asm/hwcap.h +#if defined(LITE_WITH_ANDROID) && defined(__aarch64__) +#include /* Get HWCAP bits from asm/hwcap.h */ +#include +#define AARCH64_HWCAP_SVE (1UL << 22) +#define AARCH64_HWCAP2_SVE2 (1UL << 1) +#define AARCH64_HWCAP2_SVEAES (1UL << 2) +#define AARCH64_HWCAP2_SVEPMULL (1UL << 3) +#define AARCH64_HWCAP2_SVEBITPERM (1UL << 4) +#define AARCH64_HWCAP2_SVESHA3 (1UL << 5) +#define AARCH64_HWCAP2_SVESM4 (1UL << 6) +#define AARCH64_HWCAP2_SVEI8MM (1UL << 9) +#define AARCH64_HWCAP2_SVEF32MM (1UL << 10) +#define AARCH64_HWCAP2_SVEF64MM (1UL << 11) +#define AARCH64_HWCAP2_SVEBF16 (1UL << 12) +#define AARCH64_HWCAP2_I8MM (1UL << 13) +#define AARCH64_HWCAP2_BF16 (1UL << 14) +#define AT_HWCAP 16 +#define AT_HWCAP2 26 + +bool check_sve2_valid() { + auto mask = static_cast(getauxval(AT_HWCAP2)); // Android API >= 18 + if (mask & AARCH64_HWCAP2_SVE2) return true; + return false; +} +#endif + #if ((defined LITE_WITH_ARM) || (defined LITE_WITH_MLU)) LITE_THREAD_LOCAL lite_api::PowerMode DeviceInfo::mode_; LITE_THREAD_LOCAL ARMArch DeviceInfo::arch_; @@ -225,6 +252,15 @@ void get_cpu_arch(std::vector* archs, const int cpu_num) { // 888 arch_type = kX1; break; + case 0xd46: + arch_type = kA510; + break; + case 0xd47: + arch_type = kA710; + break; + case 0xd48: + arch_type = kX2; + break; default: LOG(ERROR) << "Unknow cpu arch: " << arch_id; } @@ -1138,6 +1174,8 @@ void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) { bool DeviceInfo::set_a53_valid() { return has_a53_valid_; } +bool DeviceInfo::has_sve2() { return has_sve2_; } + int DeviceInfo::Setup() { core_num_ = get_cpu_num(); mem_size_ = get_mem_size(); @@ -1192,6 +1230,12 @@ int DeviceInfo::Setup() { } else { has_a53_valid_ = true; } + // SVE2 + has_sve2_ = false; +#if defined(LITE_WITH_ANDROID) && defined(__aarch64__) + has_sve2_ = check_sve2_valid(); +#endif + // output info LOG(INFO) << "ARM multiprocessors name: " << dev_name_; LOG(INFO) << "ARM multiprocessors number: " << core_num_; @@ -1215,6 +1259,7 @@ int DeviceInfo::Setup() { LOG(INFO) << L3_cache_[i] / 1024 << " KB"; } LOG(INFO) << "Total memory: " << mem_size_ << "KB"; + LOG(INFO) << "SVE2 support: " << has_sve2_; // set default run mode SetRunMode(lite_api::PowerMode::LITE_POWER_NO_BIND, 1); // use single thread by default @@ -1528,5 +1573,23 @@ FMAType device_fma_level() { #endif +#if defined(LITE_WITH_ANDROID) && defined(__aarch64__) +#undef AARCH64_HWCAP_SVE +#undef AARCH64_HWCAP2_SVE2 +#undef AARCH64_HWCAP2_SVEAES +#undef AARCH64_HWCAP2_SVEPMULL +#undef AARCH64_HWCAP2_SVEBITPERM +#undef AARCH64_HWCAP2_SVESHA3 +#undef AARCH64_HWCAP2_SVESM4 +#undef AARCH64_HWCAP2_SVEI8MM +#undef AARCH64_HWCAP2_SVEF32MM +#undef AARCH64_HWCAP2_SVEF64MM +#undef AARCH64_HWCAP2_SVEBF16 +#undef AARCH64_HWCAP2_I8MM +#undef AARCH64_HWCAP2_BF16 +#undef AT_HWCAP +#undef AT_HWCAP2 +#endif + } // namespace lite } // namespace paddle diff --git a/lite/core/device_info.h b/lite/core/device_info.h index 3bf6583fd69..1a0992de335 100644 --- a/lite/core/device_info.h +++ b/lite/core/device_info.h @@ -39,10 +39,12 @@ using L3CacheSetMethod = lite_api::L3CacheSetMethod; typedef enum { kAPPLE = 0, kX1 = 1, + kX2 = 2, kA35 = 35, kA53 = 53, kA55 = 55, kA57 = 57, + kA510 = 60, kA72 = 72, kA73 = 73, kA75 = 75, @@ -52,6 +54,7 @@ typedef enum { kGold = 79, kGold_Prime = 80, kSilver = 81, + kA710 = 82, kARMArch_UNKOWN = -1 } ARMArch; @@ -69,6 +72,7 @@ class DeviceInfo { int Setup(); bool set_a53_valid(); + bool has_sve2(); void SetRunMode(lite_api::PowerMode mode, int thread_num); void SetCache(int l1size, int l2size, int l3size); @@ -151,6 +155,7 @@ class DeviceInfo { std::vector fp16_; std::vector dot_; bool has_a53_valid_; + bool has_sve2_; // LITE_POWER_HIGH stands for using big cores, // LITE_POWER_LOW stands for using small core, diff --git a/lite/kernels/arm/softmax_compute.cc b/lite/kernels/arm/softmax_compute.cc index a34df082483..28b267dc8df 100644 --- a/lite/kernels/arm/softmax_compute.cc +++ b/lite/kernels/arm/softmax_compute.cc @@ -17,6 +17,9 @@ #ifdef ENABLE_ARM_FP16 #include "lite/backends/arm/math/fp16/funcs_fp16.h" #endif +#ifdef LITE_WITH_ARM8_SVE2 +#include "lite/backends/arm/math/sve/funcs_sve.h" +#endif namespace paddle { namespace lite { @@ -37,6 +40,20 @@ void SoftmaxCompute::Run() { int outer_num = x_dims.Slice(0, axis).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int axis_size = x_dims[axis]; + auto& ctx = this->ctx_->As(); +#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) + if (ctx.has_sve2()) { + if (inner_num == 1) { + lite::arm::math::sve::softmax_inner1_sve(din, dout, outer_num, axis_size); + } else if (axis_size == 4) { + lite::arm::math::sve::softmax_axis4_sve(din, dout, outer_num, axis_size); + } else { + lite::arm::math::sve::softmax_baisc_sve(din, dout, outer_num, axis_size); + } + } + return; +#endif + if (inner_num == 1) { if (axis_size > 4) { lite::arm::math::softmax_inner1_large_axis( @@ -83,6 +100,19 @@ void SoftmaxCompute::Run() { int outer_num = x_dims.Slice(0, axis).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int axis_size = x_dims[axis]; +#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) + if (ctx.has_sve2()) { + if (inner_num == 1) { + lite::arm::math::sve::softmax_inner1_sve(din, dout, outer_num, axis_size); + } else if (axis_size == 4) { + lite::arm::math::sve::softmax_axis4_sve(din, dout, outer_num, axis_size); + } else { + lite::arm::math::sve::softmax_baisc_sve(din, dout, outer_num, axis_size); + } + } + return; +#endif + if (inner_num == 1) { if (axis_size >= 8) { lite::arm::math::fp16::softmax_inner1_large_axis_fp16( diff --git a/lite/tools/build.sh b/lite/tools/build.sh index 6a6da9aa2d5..a58bf4bd7bd 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -33,6 +33,8 @@ WITH_PRECISION_PROFILE=OFF WITH_LTO=OFF BUILD_ARM82_FP16=OFF BUILD_ARM82_INT8_SDOT=OFF +# controls whether to support SVE2 instructions, default is OFF +WITH_ARM8_SVE2=OFF BUILD_NPU=OFF NPU_DDK_ROOT="$(pwd)/ai_ddk_lib/" # Download HiAI DDK from https://developer.huawei.com/consumer/cn/hiai/ BUILD_XPU=OFF @@ -225,6 +227,7 @@ function make_tiny_publish_so { -DLITE_WITH_RKNPU=$BUILD_RKNPU \ -DRKNPU_DDK_ROOT=$RKNPU_DDK_ROOT \ -DLITE_WITH_ARM82_FP16=$BUILD_ARM82_FP16 \ + -DLITE_WITH_ARM8_SVE2=$WITH_ARM8_SVE2 \ -DLITE_WITH_ARM82_INT8_SDOT=$BUILD_ARM82_INT8_SDOT \ -DLITE_THREAD_POOL=$BUILD_THREAD_POOL \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} @@ -338,6 +341,7 @@ function make_full_publish_so { -DLITE_WITH_APU=$BUILD_APU \ -DAPU_DDK_ROOT=$APU_DDK_ROOT \ -DLITE_WITH_ARM82_FP16=$BUILD_ARM82_FP16 \ + -DLITE_WITH_ARM8_SVE2=$WITH_ARM8_SVE2 \ -DLITE_WITH_ARM82_INT8_SDOT=$BUILD_ARM82_INT8_SDOT \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} @@ -406,6 +410,7 @@ function make_all_tests { -DLITE_WITH_RKNPU=$BUILD_RKNPU \ -DRKNPU_DDK_ROOT=$RKNPU_DDK_ROOT \ -DLITE_WITH_ARM82_FP16=$BUILD_ARM82_FP16 \ + -DLITE_WITH_ARM8_SVE2=$WITH_ARM8_SVE2 \ -DLITE_WITH_ARM82_INT8_SDOT=$BUILD_ARM82_INT8_SDOT \ -DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang} @@ -660,6 +665,8 @@ function print_usage { echo -e "--build_java: (OFF|ON); controls whether to publish java api lib (Only ANDROID is supported)" echo -e "--build_dir: directory for building" echo -e "--ios_deployment_target: (default: 9.0); Set the minimum compatible system version for ios deployment." + echo -e "| --with_arm8_sve2: (OFF|ON); controls whether to include SVE2 kernels, default is OFF |" + echo -e "| warning: when --with_arm8_sve2=ON, NDK version need >= r23, arch will be set as armv8. |" echo echo -e "argument choices:" echo -e "--arm_os:\t android|ios|ios64" @@ -785,6 +792,10 @@ function main { BUILD_ARM82_FP16="${i#*=}" shift ;; + --with_arm8_sve2=*) + WITH_ARM8_SVE2="${i#*=}" + shift + ;; --build_arm82_int8_sdot=*) BUILD_ARM82_INT8_SDOT="${i#*=}" shift From c13c8b699d929d60ea8aa46fb8c121d9543e82be Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Mon, 23 May 2022 16:29:01 +0800 Subject: [PATCH 02/11] update build dir --- lite/tools/build.sh | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lite/tools/build.sh b/lite/tools/build.sh index a58bf4bd7bd..240a0cad815 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -199,6 +199,14 @@ function make_tiny_publish_so { set_android_api_level CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} fi + if [ "${BUILD_ARM82_FP16}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv82_fp16" + fi + if [ "${WITH_ARM8_SVE2}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv8_sve2" + fi cmake .. \ ${PYTHON_FLAGS} \ @@ -309,6 +317,14 @@ function make_full_publish_so { set_android_api_level CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} fi + if [ "${BUILD_ARM82_FP16}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv82_fp16" + fi + if [ "${WITH_ARM8_SVE2}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv8_sve2" + fi prepare_workspace $root_dir $build_directory cmake $root_dir \ @@ -384,6 +400,15 @@ function make_all_tests { if [ $4 == "benchmark" ]; then set_benchmark_options + build_dir=build_dir + ".benchmark" + fi + if [ "${BUILD_ARM82_FP16}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv82_fp16" + fi + if [ "${WITH_ARM8_SVE2}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=build_dir + ".armv8_sve2" fi prepare_workspace $root_dir $build_directory From eb984ff153be9919c4a2e4b8e39af7258f6f070a Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Mon, 23 May 2022 20:59:43 +0800 Subject: [PATCH 03/11] fix build error --- cmake/configure.cmake | 2 +- cmake/postproject.cmake | 4 +- lite/backends/arm/math/sve/funcs_sve.h | 144 ++++++++++++++++++++- lite/backends/arm/math/sve/softmax_sve.cc | 141 ++++++++++---------- lite/kernels/arm/softmax_compute.cc | 17 ++- lite/tests/kernels/softmax_compute_test.cc | 11 +- lite/tools/build.sh | 64 ++++----- 7 files changed, 264 insertions(+), 119 deletions(-) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index f2f85bcb6df..8677d1c1b20 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -298,7 +298,7 @@ if (LITE_WITH_ARM82_FP16) endif(LITE_WITH_ARM82_FP16) if (LITE_WITH_ARM8_SVE2) - add_definitions("-DLITE_WITH_ARM82_FP16") + add_definitions("-DLITE_WITH_ARM8_SVE2c") endif(LITE_WITH_ARM8_SVE2) if (LITE_WITH_M1) diff --git a/cmake/postproject.cmake b/cmake/postproject.cmake index d098b2c893e..95423480085 100644 --- a/cmake/postproject.cmake +++ b/cmake/postproject.cmake @@ -42,13 +42,15 @@ if(ANDROID) endif() if(LITE_WITH_ARM8_SVE2) - if(${ANDROID_NDK_MAJOR}) + if ((ARM_TARGET_ARCH_ABI STREQUAL "armv8") and ${ANDROID_NDK_MAJOR}) if(${ANDROID_NDK_MAJOR} GREATER_EQUAL "23") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+sve2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+sve2") else() message(FATAL_ERROR "NDK VERSION: ${ANDROID_NDK_MAJOR}, however it must be greater equal 23 when sve2 is ON") endif() + else() + message(FATAL_ERROR "The arm_abi is ${ARM_TARGET_ARCH_ABI}, the arm_abi must be armv8 when sve2 is ON") endif() endif() diff --git a/lite/backends/arm/math/sve/funcs_sve.h b/lite/backends/arm/math/sve/funcs_sve.h index e16cc69899f..e55d774e6b9 100644 --- a/lite/backends/arm/math/sve/funcs_sve.h +++ b/lite/backends/arm/math/sve/funcs_sve.h @@ -11,13 +11,40 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +/* + * The following function is base on + * https://github.com/ARM-software/ComputeLibrary/ + * + * Copyright (c) 2017-2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ #pragma once #include - +#include #include #include +#include #include "lite/backends/arm/math/funcs.h" #include "lite/backends/arm/math/sve/softmax_sve.h" @@ -25,7 +52,6 @@ #include "lite/backends/arm/math/fp16/funcs_fp16.h" typedef __fp16 float16_t; #endif -#include namespace paddle { namespace lite { @@ -102,7 +128,6 @@ SVDUP_N_IMPL(uint32_t, svuint32_t, u32) SVDUP_N_IMPL(uint64_t, svuint64_t, u64) SVDUP_N_IMPL(float16_t, svfloat16_t, f16) SVDUP_N_IMPL(float, svfloat32_t, f32) -SVDUP_N_IMPL(bfloat16_t, svbfloat16_t, bf16) #undef SVDUP_N_IMPL @@ -131,6 +156,119 @@ SVWHILELT_IMPL(int64_t) #undef SVWHILELT_IMPL +inline svfloat32_t svtaylor_poly_f32_z(svbool_t pg, + svfloat32_t x, + svfloat32_t coeff_1, + svfloat32_t coeff_2, + svfloat32_t coeff_3, + svfloat32_t coeff_4, + svfloat32_t coeff_5, + svfloat32_t coeff_6, + svfloat32_t coeff_7, + svfloat32_t coeff_8) { + const auto A = svmla_f32_z(pg, coeff_1, coeff_5, x); + const auto B = svmla_f32_z(pg, coeff_3, coeff_7, x); + const auto C = svmla_f32_z(pg, coeff_2, coeff_6, x); + const auto D = svmla_f32_z(pg, coeff_4, coeff_8, x); + const auto x2 = svmul_f32_z(pg, x, x); + const auto x4 = svmul_f32_z(pg, x2, x2); + const auto res = + svmla_f32_z(pg, svmla_f32_z(pg, A, B, x2), svmla_f32_z(pg, C, D, x2), x4); + return res; +} + +inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) { + const auto CONST_LN2 = svdup_n_f32(0.6931471805f); // ln(2) + const auto CONST_INV_LN2 = svdup_n_f32(1.4426950408f); // 1/ln(2) + const auto CONST_INF = svdup_n_f32(std::numeric_limits::infinity()); + const auto CONST_MAX_INPUT = svdup_n_f32(88.7f); + const auto CONST_0 = svdup_n_f32(0.f); + const auto CONST_NEGATIVE_126 = svdup_n_s32(-126); + + /** Exponent polynomial coefficients */ + const svfloat32_t exp_tab_1 = svdup_n_f32(1.f); + const svfloat32_t exp_tab_2 = svdup_n_f32(0.0416598916054f); + const svfloat32_t exp_tab_3 = svdup_n_f32(0.500000596046f); + const svfloat32_t exp_tab_4 = svdup_n_f32(0.0014122662833f); + const svfloat32_t exp_tab_5 = svdup_n_f32(1.00000011921f); + const svfloat32_t exp_tab_6 = svdup_n_f32(0.00833693705499f); + const svfloat32_t exp_tab_7 = svdup_n_f32(0.166665703058f); + const svfloat32_t exp_tab_8 = svdup_n_f32(0.000195780929062f); + + // Perform range reduction [-log(2),log(2)] + auto m = svcvt_s32_f32_z(pg, svmul_f32_z(pg, x, CONST_INV_LN2)); + auto val = svmls_f32_z(pg, x, svcvt_f32_s32_z(pg, m), CONST_LN2); + + // Polynomial Approximation + auto poly = svtaylor_poly_f32_z(pg, + val, + exp_tab_1, + exp_tab_2, + exp_tab_3, + exp_tab_4, + exp_tab_5, + exp_tab_6, + exp_tab_7, + exp_tab_8); + + // Reconstruct + poly = svreinterpret_f32_s32( + svqadd_s32(svreinterpret_s32_f32(poly), svlsl_n_s32_z(pg, m, 23))); + + // Handle underflow + svbool_t ltpg = svcmplt_s32(pg, m, CONST_NEGATIVE_126); + poly = svsel_f32(ltpg, CONST_0, poly); + + // Handle overflow + svbool_t gtpg = svcmpgt_f32(pg, x, CONST_MAX_INPUT); + poly = svsel_f32(gtpg, CONST_INF, poly); + + return poly; +} + +#ifdef ENABLE_ARM_FP16 +inline svfloat16_t svtaylor_poly_f16_z(svbool_t pg, + svfloat16_t x, + svfloat16_t coeff_1, + svfloat16_t coeff_2, + svfloat16_t coeff_3, + svfloat16_t coeff_4, + svfloat16_t coeff_5, + svfloat16_t coeff_6, + svfloat16_t coeff_7, + svfloat16_t coeff_8) { + const auto A = svmla_f16_z(pg, coeff_1, coeff_5, x); + const auto B = svmla_f16_z(pg, coeff_3, coeff_7, x); + const auto C = svmla_f16_z(pg, coeff_2, coeff_6, x); + const auto D = svmla_f16_z(pg, coeff_4, coeff_8, x); + const auto x2 = svmul_f16_z(pg, x, x); + const auto x4 = svmul_f16_z(pg, x2, x2); + const auto res = + svmla_f16_z(pg, svmla_f16_z(pg, A, B, x2), svmla_f16_z(pg, C, D, x2), x4); + return res; +} + +inline svfloat16_t svexp_f16_z(svbool_t pg, svfloat16_t x) { + auto bottom = svcvt_f32_z(pg, x); +#if defined(LITE_WITH_ARM8_SVE2) + auto top = svcvtlt_f32_x(pg, x); + auto pg_top = pg; +#else + auto pg_top = svptrue_b16(); + auto top = svcvt_f32_z( + pg_top, svreinterpret_f16(svrevh_z(svptrue_b16(), svreinterpret_u32(x)))); +#endif + bottom = svexp_f32_z(pg, bottom); + top = svexp_f32_z(pg_top, top); + +#if defined(LITE_WITH_ARM8_SVE2) + return svcvtnt_f16_m(svcvt_f16_z(pg, bottom), pg_top, top); +#else + return svtrn1(svcvt_f16_z(pg, bottom), svcvt_f16_z(pg_top, top)); +#endif +} +#endif + template inline svbool_t svwhilelt(IndexType a, IndexType b) { return svwhilelt_size(a, b); diff --git a/lite/backends/arm/math/sve/softmax_sve.cc b/lite/backends/arm/math/sve/softmax_sve.cc index 8314b0391c6..d9a94ab28e6 100644 --- a/lite/backends/arm/math/sve/softmax_sve.cc +++ b/lite/backends/arm/math/sve/softmax_sve.cc @@ -71,7 +71,6 @@ void softmax_axis4_sve(const Dtype* din, const int outer_num) { int compute_size = inner_num * outer_num; auto vone = svdup_n(static_cast(1)); - auto vinf = svdiv_z(vone, vsum); int i = 0; LITE_PARALLEL_COMMON_BEGIN(i, tid, compute_size, 0, svcnt()) { int idx_inner = i % inner_num; @@ -103,6 +102,7 @@ void softmax_axis4_sve(const Dtype* din, auto vsum_0 = svadd_m(pg, vsum0, vsum1); auto vsum_1 = svadd_m(pg, vsum2, vsum3); auto vsum = svadd_m(pg, vsum_0, vsum_1); + auto vinf = svdiv_z(pg, vone, vsum); auto vout0 = svmul_z(pg, vsum0, vinf); auto vout1 = svmul_z(pg, vsum1, vinf); auto vout2 = svmul_z(pg, vsum2, vinf); @@ -131,37 +131,37 @@ void softmax_inner1_sve(const Dtype* din, const Dtype* din_ptr1 = din_ptr0 + axis_size; const Dtype* din_ptr2 = din_ptr1 + axis_size; const Dtype* din_ptr3 = din_ptr2 + axis_size; - Dtype* din_max_ptr0 = din_ptr0; - Dtype* din_max_ptr1 = din_ptr1; - Dtype* din_max_ptr2 = din_ptr2; - Dtype* din_max_ptr3 = din_ptr3; + const Dtype* din_max_ptr0 = din_ptr0; + const Dtype* din_max_ptr1 = din_ptr1; + const Dtype* din_max_ptr2 = din_ptr2; + const Dtype* din_max_ptr3 = din_ptr3; int x = 0; auto pg0 = svwhilelt(x, axis_size); - auto vec_max0 = svdup_n(support::cpp11::lowest()); - auto vec_max1 = svdup_n(support::cpp11::lowest()); - auto vec_max2 = svdup_n(support::cpp11::lowest()); - auto vec_max3 = svdup_n(support::cpp11::lowest()); + auto vec_max0 = svdup_n(static_cast(-FLT_MAX)); + auto vec_max1 = svdup_n(static_cast(-FLT_MAX)); + auto vec_max2 = svdup_n(static_cast(-FLT_MAX)); + auto vec_max3 = svdup_n(static_cast(-FLT_MAX)); do { auto vdata0 = svld1(pg, din_max_ptr0); auto vdata1 = svld1(pg, din_max_ptr1); auto vdata2 = svld1(pg, din_max_ptr2); auto vdata3 = svld1(pg, din_max_ptr3); // get max - auto vmax0 = svmax_m(pg, vec_max0, vdata0); - auto vmax1 = svmax_m(pg, vec_max1, vdata1); - auto vmax0 = svmax_m(pg, vec_max2, vdata2); - auto vmax1 = svmax_m(pg, vec_max3, vdata3); + vec_max0 = svmax_m(pg, vec_max0, vdata0); + vec_max1 = svmax_m(pg, vec_max1, vdata1); + vec_max2 = svmax_m(pg, vec_max2, vdata2); + vec_max3 = svmax_m(pg, vec_max3, vdata3); din_max_ptr0 += svcnt(); din_max_ptr1 += svcnt(); din_max_ptr2 += svcnt(); din_max_ptr3 += svcnt(); x += svcnt(); - pg = svwhilelt(x, axis_size); + pg0 = svwhilelt(x, axis_size); } while (svptest_any(all_true_pg, pg0)); - Dtype vmax_0 = svmaxv(vec_max0); - Dtype vmax_1 = svmaxv(vec_max1); - Dtype vmax_2 = svmaxv(vec_max2); - Dtype vmax_3 = svmaxv(vec_max3); + Dtype vmax_0 = svmaxv(pg, vec_max0); + Dtype vmax_1 = svmaxv(pg, vec_max1); + Dtype vmax_2 = svmaxv(pg, vec_max2); + Dtype vmax_3 = svmaxv(pg, vec_max3); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -182,10 +182,14 @@ void softmax_inner1_sve(const Dtype* din, auto vmax3 = svdup_n(vmax_3); for (int j = 0; j < axis_size; j += svcnt()) { auto pg0 = svwhilelt(j, axis_size); - auto vsub_exp0 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); - auto vsub_exp1 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr1), vmax1)); - auto vsub_exp2 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr2), vmax2)); - auto vsub_exp3 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr3), vmax3)); + auto vsub_exp0 = + svexp_z(pg0, svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); + auto vsub_exp1 = + svexp_z(pg0, svsub_z(pg0, svld1(pg0, din_max_ptr1), vmax1)); + auto vsub_exp2 = + svexp_z(pg0, svsub_z(pg0, svld1(pg0, din_max_ptr2), vmax2)); + auto vsub_exp3 = + svexp_z(pg0, svsub_z(pg0, svld1(pg0, din_max_ptr3), vmax3)); vsum0 = svadd_m(pg0, vsum0, vsub_exp0); vsum1 = svadd_m(pg0, vsum1, vsub_exp1); vsum2 = svadd_m(pg0, vsum2, vsub_exp2); @@ -203,14 +207,14 @@ void softmax_inner1_sve(const Dtype* din, dout_ptr2 += svcnt(); dout_ptr3 += svcnt(); } - auto vsum_0 = svaddv(vsum0); - auto vsum_1 = svaddv(vsum1); - auto vsum_2 = svaddv(vsum2); - auto vsum_3 = svaddv(vsum3); - auto vinf0 = svmul_z(pg, svdup_n(vsum_0), vinf); - auto vinf1 = svmul_z(pg, svdup_n(vsum_1), vinf); - auto vinf2 = svmul_z(pg, svdup_n(vsum_2), vinf); - auto vinf3 = svmul_z(pg, svdup_n(vsum_3), vinf); + auto vsum_0 = svaddv(pg, vsum0); + auto vsum_1 = svaddv(pg, vsum1); + auto vsum_2 = svaddv(pg, vsum2); + auto vsum_3 = svaddv(pg, vsum3); + auto vinf0 = svdiv_z(pg, vone, svdup_n(vsum_0)); + auto vinf1 = svdiv_z(pg, vone, svdup_n(vsum_1)); + auto vinf2 = svdiv_z(pg, vone, svdup_n(vsum_2)); + auto vinf3 = svdiv_z(pg, vone, svdup_n(vsum_3)); dout_ptr0 = dout + index; dout_ptr1 = dout_ptr0 + axis_size; dout_ptr2 = dout_ptr1 + axis_size; @@ -236,10 +240,10 @@ void softmax_inner1_sve(const Dtype* din, auto index = i * axis_size; auto pg = svwhilelt(i, outer_size); const Dtype* din_ptr0 = din + index; - Dtype* din_max_ptr0 = din_ptr0; + const Dtype* din_max_ptr0 = din_ptr0; int x = 0; auto pg0 = svwhilelt(x, axis_size); - auto vec_max0 = svdup_n(support::cpp11::lowest()); + auto vec_max0 = svdup_n(static_cast(-FLT_MAX)); do { auto vdata0 = svld1(pg, din_max_ptr0); // get max @@ -248,7 +252,7 @@ void softmax_inner1_sve(const Dtype* din, x += svcnt(); pg = svwhilelt(x, axis_size); } while (svptest_any(all_true_pg, pg0)); - Dtype vmax_0 = svmaxv(vec_max0); + Dtype vmax_0 = svmaxv(pg, vec_max0); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -257,15 +261,16 @@ void softmax_inner1_sve(const Dtype* din, auto vmax0 = svdup_n(vmax_0); for (int j = 0; j < axis_size; j += svcnt()) { auto pg0 = svwhilelt(j, axis_size); - auto vsub_exp0 = svexp_z(svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); + auto vsub_exp0 = + svexp_z(pg0, svsub_z(pg0, svld1(pg0, din_max_ptr0), vmax0)); vsum0 = svadd_m(pg0, vsum0, vsub_exp0); din_max_ptr0 += svcnt(); svst1(pg0, dout_ptr0, vsub_exp0); dout_ptr0 += svcnt(); } - auto vsum_0 = svaddv(vsum0); - auto vinf0 = svmul_z(pg, svdup_n(vsum_0), vinf); + auto vsum_0 = svaddv(pg, vsum0); + auto vinf0 = svdiv_z(pg, vone, svdup_n(vsum_0)); dout_ptr0 = dout + index; for (int j = 0; j < axis_size; j += svcnt()) { auto pg0 = svwhilelt(j, axis_size); @@ -277,48 +282,40 @@ void softmax_inner1_sve(const Dtype* din, LITE_PARALLEL_END() } -template <> -softmax_basic_sve(const float* din, - float* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_basic_sve(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); -template <> -softmax_axis4_sve(const float* din, - float* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_axis4_sve(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num); -template <> -softmax_inner1_sve(const float* din, - float* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_inner1_sve(const float* din, + float* dout, + const int outer_size, + const int axis_size); #ifdef ENABLE_ARM_FP16 -template <> -softmax_basic_sve(const float16_t* din, - float16_t* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_basic_sve(const float16_t* din, + float16_t* dout, + const int axis_size, + const int inner_num, + const int outer_num); -template <> -softmax_axis4_sve(const float16_t* din, - float* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_axis4_sve(const float16_t* din, + float16_t* dout, + const int axis_size, + const int inner_num, + const int outer_num); -template <> -softmax_inner1_sve(const float16_t* din, - float16_t* dout, - const int axis_size, - const int inner_num, - const int outer_num); +template void softmax_inner1_sve(const float16_t* din, + float16_t* dout, + const int outer_size, + const int axis_size); #endif } // namespace sve diff --git a/lite/kernels/arm/softmax_compute.cc b/lite/kernels/arm/softmax_compute.cc index 28b267dc8df..55bf1adeedc 100644 --- a/lite/kernels/arm/softmax_compute.cc +++ b/lite/kernels/arm/softmax_compute.cc @@ -41,14 +41,16 @@ void SoftmaxCompute::Run() { int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int axis_size = x_dims[axis]; auto& ctx = this->ctx_->As(); -#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) +#ifdef LITE_WITH_ARM8_SVE2 if (ctx.has_sve2()) { if (inner_num == 1) { lite::arm::math::sve::softmax_inner1_sve(din, dout, outer_num, axis_size); } else if (axis_size == 4) { - lite::arm::math::sve::softmax_axis4_sve(din, dout, outer_num, axis_size); + lite::arm::math::sve::softmax_axis4_sve( + din, dout, axis_size, inner_num, outer_num); } else { - lite::arm::math::sve::softmax_baisc_sve(din, dout, outer_num, axis_size); + lite::arm::math::sve::softmax_basic_sve( + din, dout, axis_size, inner_num, outer_num); } } return; @@ -100,14 +102,17 @@ void SoftmaxCompute::Run() { int outer_num = x_dims.Slice(0, axis).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int axis_size = x_dims[axis]; -#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) + auto& ctx = this->ctx_->As(); +#ifdef LITE_WITH_ARM8_SVE2 if (ctx.has_sve2()) { if (inner_num == 1) { lite::arm::math::sve::softmax_inner1_sve(din, dout, outer_num, axis_size); } else if (axis_size == 4) { - lite::arm::math::sve::softmax_axis4_sve(din, dout, outer_num, axis_size); + lite::arm::math::sve::softmax_axis4_sve( + din, dout, axis_size, inner_num, outer_num); } else { - lite::arm::math::sve::softmax_baisc_sve(din, dout, outer_num, axis_size); + lite::arm::math::sve::softmax_basic_sve( + din, dout, axis_size, inner_num, outer_num); } } return; diff --git a/lite/tests/kernels/softmax_compute_test.cc b/lite/tests/kernels/softmax_compute_test.cc index 59bbfe26a00..4a225281c72 100644 --- a/lite/tests/kernels/softmax_compute_test.cc +++ b/lite/tests/kernels/softmax_compute_test.cc @@ -100,8 +100,9 @@ TEST(Softmax, precision) { LOG(INFO) << "test softmax op"; float abs_error = 4e-5; Place place; + = #if defined(LITE_WITH_NNADAPTER) - place = TARGET(kNNAdapter); + place = TARGET(kNNAdapter); #if defined(NNADAPTER_WITH_HUAWEI_ASCEND_NPU) abs_error = 1e-2; #elif defined(NNADAPTER_WITH_CAMBRICON_MLU) @@ -114,16 +115,18 @@ TEST(Softmax, precision) { return; #endif #elif defined(LITE_WITH_NPU) - place = TARGET(kNPU); + place = TARGET(kNPU); abs_error = 4e-3; // Using fp16 in NPU // #elif defined(LITE_WITH_OPENCL) // place = Place(TARGET(kOpenCL), PRECISION(kFP16), // DATALAYOUT(kImageDefault)); // abs_error = 1e-2; // Using fp16 in OPENCL #elif defined(LITE_WITH_XPU) - place = TARGET(kXPU); + place = TARGET(kXPU); +#elif defined(LITE_WITH_ARM) + place = TARGET(kARM); #else - return; + return; #endif for (auto x_dims : diff --git a/lite/tools/build.sh b/lite/tools/build.sh index 240a0cad815..4e69083eeaa 100755 --- a/lite/tools/build.sh +++ b/lite/tools/build.sh @@ -183,6 +183,14 @@ function make_tiny_publish_so { if [ ! -d third-party ]; then git checkout third-party fi + if [ "${BUILD_ARM82_FP16}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=$build_dir".armv82_fp16" + fi + if [ "${WITH_ARM8_SVE2}" == "ON" ]; then + TOOLCHAIN=clang + build_dir=$build_dir".armv8_sve2" + fi if [ -d $build_dir ] then @@ -199,14 +207,6 @@ function make_tiny_publish_so { set_android_api_level CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} fi - if [ "${BUILD_ARM82_FP16}" == "ON" ]; then - TOOLCHAIN=clang - build_dir=build_dir + ".armv82_fp16" - fi - if [ "${WITH_ARM8_SVE2}" == "ON" ]; then - TOOLCHAIN=clang - build_dir=build_dir + ".armv8_sve2" - fi cmake .. \ ${PYTHON_FLAGS} \ @@ -301,6 +301,14 @@ function make_full_publish_so { root_dir=$(pwd) build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang} + if [ "${BUILD_ARM82_FP16}" == "ON" ]; then + TOOLCHAIN=clang + build_directory=$build_directory".armv82_fp16" + fi + if [ "${WITH_ARM8_SVE2}" == "ON" ]; then + TOOLCHAIN=clang + build_directory=$build_directory".armv8_sve2" + fi if [ -d $build_directory ] then @@ -317,14 +325,6 @@ function make_full_publish_so { set_android_api_level CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} fi - if [ "${BUILD_ARM82_FP16}" == "ON" ]; then - TOOLCHAIN=clang - build_dir=build_dir + ".armv82_fp16" - fi - if [ "${WITH_ARM8_SVE2}" == "ON" ]; then - TOOLCHAIN=clang - build_dir=build_dir + ".armv8_sve2" - fi prepare_workspace $root_dir $build_directory cmake $root_dir \ @@ -386,29 +386,29 @@ function make_all_tests { prepare_thirdparty root_dir=$(pwd) build_directory=$BUILD_DIR/build.lite.${os}.${abi}.${lang} - if [ -d $build_dir ] - then - rm -rf $build_dir - fi - mkdir -p $build_directory - - cd $build_directory - if [ ${os} == "android" ]; then - set_android_api_level - CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} - fi - if [ $4 == "benchmark" ]; then set_benchmark_options - build_dir=build_dir + ".benchmark" + build_directory=$build_directory".benchmark" fi if [ "${BUILD_ARM82_FP16}" == "ON" ]; then TOOLCHAIN=clang - build_dir=build_dir + ".armv82_fp16" + build_directory=$build_directory".armv82_fp16" fi if [ "${WITH_ARM8_SVE2}" == "ON" ]; then TOOLCHAIN=clang - build_dir=build_dir + ".armv8_sve2" + build_directory=$build_directory".armv8_sve2" + fi + + if [ -d $build_directory ] + then + rm -rf $build_directory + fi + mkdir -p $build_directory + + cd $build_directory + if [ ${os} == "android" ]; then + set_android_api_level + CMAKE_EXTRA_OPTIONS=${CMAKE_EXTRA_OPTIONS}" "${CMAKE_API_LEVEL_OPTIONS} fi prepare_workspace $root_dir $build_directory @@ -817,7 +817,7 @@ function main { BUILD_ARM82_FP16="${i#*=}" shift ;; - --with_arm8_sve2=*) + --build_arm8_sve2=*) WITH_ARM8_SVE2="${i#*=}" shift ;; From b8bd4a19e2aa552fef1331ab7032ec715c2c3d05 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Tue, 24 May 2022 11:58:37 +0800 Subject: [PATCH 04/11] aa --- lite/backends/arm/math/sve/softmax_sve.cc | 67 +++++++++++------------ 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/lite/backends/arm/math/sve/softmax_sve.cc b/lite/backends/arm/math/sve/softmax_sve.cc index d9a94ab28e6..362332a3e4f 100644 --- a/lite/backends/arm/math/sve/softmax_sve.cc +++ b/lite/backends/arm/math/sve/softmax_sve.cc @@ -122,11 +122,9 @@ void softmax_inner1_sve(const Dtype* din, const int axis_size) { int out_cnt = (outer_size >> 2) << 2; auto vone = svdup_n(static_cast(1)); - const auto all_true_pg = svptrue(); int i = 0; LITE_PARALLEL_COMMON_BEGIN(i, tid, outer_size - 3, 0, 4) { auto index = i * axis_size; - auto pg = svwhilelt(i, outer_size); const Dtype* din_ptr0 = din + index; const Dtype* din_ptr1 = din_ptr0 + axis_size; const Dtype* din_ptr2 = din_ptr1 + axis_size; @@ -141,27 +139,26 @@ void softmax_inner1_sve(const Dtype* din, auto vec_max1 = svdup_n(static_cast(-FLT_MAX)); auto vec_max2 = svdup_n(static_cast(-FLT_MAX)); auto vec_max3 = svdup_n(static_cast(-FLT_MAX)); - do { - auto vdata0 = svld1(pg, din_max_ptr0); - auto vdata1 = svld1(pg, din_max_ptr1); - auto vdata2 = svld1(pg, din_max_ptr2); - auto vdata3 = svld1(pg, din_max_ptr3); + for (int j = 0; j < axis_size; j++) { + pg0 = svwhilelt(j, axis_size); + auto vdata0 = svld1(pg0, din_max_ptr0); + auto vdata1 = svld1(pg0, din_max_ptr1); + auto vdata2 = svld1(pg0, din_max_ptr2); + auto vdata3 = svld1(pg0, din_max_ptr3); // get max - vec_max0 = svmax_m(pg, vec_max0, vdata0); - vec_max1 = svmax_m(pg, vec_max1, vdata1); - vec_max2 = svmax_m(pg, vec_max2, vdata2); - vec_max3 = svmax_m(pg, vec_max3, vdata3); + vec_max0 = svmax_m(pg0, vec_max0, vdata0); + vec_max1 = svmax_m(pg0, vec_max1, vdata1); + vec_max2 = svmax_m(pg0, vec_max2, vdata2); + vec_max3 = svmax_m(pg0, vec_max3, vdata3); din_max_ptr0 += svcnt(); din_max_ptr1 += svcnt(); din_max_ptr2 += svcnt(); din_max_ptr3 += svcnt(); - x += svcnt(); - pg0 = svwhilelt(x, axis_size); - } while (svptest_any(all_true_pg, pg0)); - Dtype vmax_0 = svmaxv(pg, vec_max0); - Dtype vmax_1 = svmaxv(pg, vec_max1); - Dtype vmax_2 = svmaxv(pg, vec_max2); - Dtype vmax_3 = svmaxv(pg, vec_max3); + } + Dtype vmax_0 = svmaxv(pg0, vec_max0); + Dtype vmax_1 = svmaxv(pg0, vec_max1); + Dtype vmax_2 = svmaxv(pg0, vec_max2); + Dtype vmax_3 = svmaxv(pg0, vec_max3); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -207,14 +204,14 @@ void softmax_inner1_sve(const Dtype* din, dout_ptr2 += svcnt(); dout_ptr3 += svcnt(); } - auto vsum_0 = svaddv(pg, vsum0); - auto vsum_1 = svaddv(pg, vsum1); - auto vsum_2 = svaddv(pg, vsum2); - auto vsum_3 = svaddv(pg, vsum3); - auto vinf0 = svdiv_z(pg, vone, svdup_n(vsum_0)); - auto vinf1 = svdiv_z(pg, vone, svdup_n(vsum_1)); - auto vinf2 = svdiv_z(pg, vone, svdup_n(vsum_2)); - auto vinf3 = svdiv_z(pg, vone, svdup_n(vsum_3)); + auto vsum_0 = svaddv(pg0, vsum0); + auto vsum_1 = svaddv(pg0, vsum1); + auto vsum_2 = svaddv(pg0, vsum2); + auto vsum_3 = svaddv(pg0, vsum3); + auto vinf0 = svdiv_z(pg0, vone, svdup_n(vsum_0)); + auto vinf1 = svdiv_z(pg0, vone, svdup_n(vsum_1)); + auto vinf2 = svdiv_z(pg0, vone, svdup_n(vsum_2)); + auto vinf3 = svdiv_z(pg0, vone, svdup_n(vsum_3)); dout_ptr0 = dout + index; dout_ptr1 = dout_ptr0 + axis_size; dout_ptr2 = dout_ptr1 + axis_size; @@ -238,21 +235,19 @@ void softmax_inner1_sve(const Dtype* din, LITE_PARALLEL_END() LITE_PARALLEL_COMMON_BEGIN(i, tid, outer_size, out_cnt, 1) { auto index = i * axis_size; - auto pg = svwhilelt(i, outer_size); const Dtype* din_ptr0 = din + index; const Dtype* din_max_ptr0 = din_ptr0; int x = 0; auto pg0 = svwhilelt(x, axis_size); auto vec_max0 = svdup_n(static_cast(-FLT_MAX)); - do { - auto vdata0 = svld1(pg, din_max_ptr0); + for (int j = 0; j < axis_size; j += svcnt()) { + pg0 = svwhilelt(j, axis_size); + auto vdata0 = svld1(pg0, din_max_ptr0); // get max - auto vmax0 = svmax_m(pg, vec_max0, vdata0); + auto vmax0 = svmax_m(pg0, vec_max0, vdata0); din_max_ptr0 += svcnt(); - x += svcnt(); - pg = svwhilelt(x, axis_size); - } while (svptest_any(all_true_pg, pg0)); - Dtype vmax_0 = svmaxv(pg, vec_max0); + } + Dtype vmax_0 = svmaxv(pg0, vec_max0); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -269,8 +264,8 @@ void softmax_inner1_sve(const Dtype* din, svst1(pg0, dout_ptr0, vsub_exp0); dout_ptr0 += svcnt(); } - auto vsum_0 = svaddv(pg, vsum0); - auto vinf0 = svdiv_z(pg, vone, svdup_n(vsum_0)); + auto vsum_0 = svaddv(pg0, vsum0); + auto vinf0 = svdiv_z(pg0, vone, svdup_n(vsum_0)); dout_ptr0 = dout + index; for (int j = 0; j < axis_size; j += svcnt()) { auto pg0 = svwhilelt(j, axis_size); From 9ac7f09f1a6b7f95e60e160f1e9eac6c9b8dc0bf Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Tue, 24 May 2022 14:26:21 +0800 Subject: [PATCH 05/11] fix run diff --- lite/backends/arm/math/sve/softmax_sve.cc | 34 ++++++++++++----------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/lite/backends/arm/math/sve/softmax_sve.cc b/lite/backends/arm/math/sve/softmax_sve.cc index 362332a3e4f..5c79b69c4dd 100644 --- a/lite/backends/arm/math/sve/softmax_sve.cc +++ b/lite/backends/arm/math/sve/softmax_sve.cc @@ -71,6 +71,7 @@ void softmax_axis4_sve(const Dtype* din, const int outer_num) { int compute_size = inner_num * outer_num; auto vone = svdup_n(static_cast(1)); + const auto all_true_pg = svptrue(); int i = 0; LITE_PARALLEL_COMMON_BEGIN(i, tid, compute_size, 0, svcnt()) { int idx_inner = i % inner_num; @@ -122,6 +123,7 @@ void softmax_inner1_sve(const Dtype* din, const int axis_size) { int out_cnt = (outer_size >> 2) << 2; auto vone = svdup_n(static_cast(1)); + const auto all_true_pg = svptrue(); int i = 0; LITE_PARALLEL_COMMON_BEGIN(i, tid, outer_size - 3, 0, 4) { auto index = i * axis_size; @@ -155,10 +157,10 @@ void softmax_inner1_sve(const Dtype* din, din_max_ptr2 += svcnt(); din_max_ptr3 += svcnt(); } - Dtype vmax_0 = svmaxv(pg0, vec_max0); - Dtype vmax_1 = svmaxv(pg0, vec_max1); - Dtype vmax_2 = svmaxv(pg0, vec_max2); - Dtype vmax_3 = svmaxv(pg0, vec_max3); + Dtype vmax_0 = svmaxv(all_true_pg, vec_max0); + Dtype vmax_1 = svmaxv(all_true_pg, vec_max1); + Dtype vmax_2 = svmaxv(all_true_pg, vec_max2); + Dtype vmax_3 = svmaxv(all_true_pg, vec_max3); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -204,14 +206,14 @@ void softmax_inner1_sve(const Dtype* din, dout_ptr2 += svcnt(); dout_ptr3 += svcnt(); } - auto vsum_0 = svaddv(pg0, vsum0); - auto vsum_1 = svaddv(pg0, vsum1); - auto vsum_2 = svaddv(pg0, vsum2); - auto vsum_3 = svaddv(pg0, vsum3); - auto vinf0 = svdiv_z(pg0, vone, svdup_n(vsum_0)); - auto vinf1 = svdiv_z(pg0, vone, svdup_n(vsum_1)); - auto vinf2 = svdiv_z(pg0, vone, svdup_n(vsum_2)); - auto vinf3 = svdiv_z(pg0, vone, svdup_n(vsum_3)); + auto vsum_0 = svaddv(all_true_pg, vsum0); + auto vsum_1 = svaddv(all_true_pg, vsum1); + auto vsum_2 = svaddv(all_true_pg, vsum2); + auto vsum_3 = svaddv(all_true_pg, vsum3); + auto vinf0 = svdiv_z(all_true_pg, vone, svdup_n(vsum_0)); + auto vinf1 = svdiv_z(all_true_pg, vone, svdup_n(vsum_1)); + auto vinf2 = svdiv_z(all_true_pg, vone, svdup_n(vsum_2)); + auto vinf3 = svdiv_z(all_true_pg, vone, svdup_n(vsum_3)); dout_ptr0 = dout + index; dout_ptr1 = dout_ptr0 + axis_size; dout_ptr2 = dout_ptr1 + axis_size; @@ -244,10 +246,10 @@ void softmax_inner1_sve(const Dtype* din, pg0 = svwhilelt(j, axis_size); auto vdata0 = svld1(pg0, din_max_ptr0); // get max - auto vmax0 = svmax_m(pg0, vec_max0, vdata0); + vec_max0 = svmax_m(pg0, vec_max0, vdata0); din_max_ptr0 += svcnt(); } - Dtype vmax_0 = svmaxv(pg0, vec_max0); + Dtype vmax_0 = svmaxv(all_true_pg, vec_max0); // sub, exp and sum x = 0; din_max_ptr0 = din_ptr0; @@ -264,8 +266,8 @@ void softmax_inner1_sve(const Dtype* din, svst1(pg0, dout_ptr0, vsub_exp0); dout_ptr0 += svcnt(); } - auto vsum_0 = svaddv(pg0, vsum0); - auto vinf0 = svdiv_z(pg0, vone, svdup_n(vsum_0)); + auto vsum_0 = svaddv(all_true_pg, vsum0); + auto vinf0 = svdiv_z(all_true_pg, vone, svdup_n(vsum_0)); dout_ptr0 = dout + index; for (int j = 0; j < axis_size; j += svcnt()) { auto pg0 = svwhilelt(j, axis_size); From f3b3703a6072a653834c96fae489ada403354eac Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Tue, 24 May 2022 15:27:53 +0800 Subject: [PATCH 06/11] Update configure.cmake --- cmake/configure.cmake | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 6230d92327c..60a9a93b734 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -298,7 +298,7 @@ if (LITE_WITH_ARM82_FP16) endif(LITE_WITH_ARM82_FP16) if (LITE_WITH_ARM8_SVE2) - add_definitions("-DLITE_WITH_ARM8_SVE2c") + add_definitions("-DLITE_WITH_ARM8_SVE2") endif(LITE_WITH_ARM8_SVE2) if (LITE_WITH_M1) @@ -309,7 +309,3 @@ if (EMSCRIPTEN) add_compile_options("-pthread") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pthread") endif() - -if (LITE_WITH_ARM8_SVE2) - add_definitions("-DLITE_WITH_ARM8_SVE2") -endif() From 5139785b2f9848801c5d3fe84bd3f56df28de7fb Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Tue, 24 May 2022 15:36:15 +0800 Subject: [PATCH 07/11] fix build error --- cmake/postproject.cmake | 4 +++- lite/core/device_info.cc | 27 --------------------------- 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/cmake/postproject.cmake b/cmake/postproject.cmake index 95423480085..eb427e89ed0 100644 --- a/cmake/postproject.cmake +++ b/cmake/postproject.cmake @@ -42,13 +42,15 @@ if(ANDROID) endif() if(LITE_WITH_ARM8_SVE2) - if ((ARM_TARGET_ARCH_ABI STREQUAL "armv8") and ${ANDROID_NDK_MAJOR}) + if ((ARM_TARGET_ARCH_ABI STREQUAL "armv8")) + if (${ANDROID_NDK_MAJOR}) if(${ANDROID_NDK_MAJOR} GREATER_EQUAL "23") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+sve2") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+sve2") else() message(FATAL_ERROR "NDK VERSION: ${ANDROID_NDK_MAJOR}, however it must be greater equal 23 when sve2 is ON") endif() + endif() else() message(FATAL_ERROR "The arm_abi is ${ARM_TARGET_ARCH_ABI}, the arm_abi must be armv8 when sve2 is ON") endif() diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc index 349f5a5973a..5741e778209 100644 --- a/lite/core/device_info.cc +++ b/lite/core/device_info.cc @@ -106,33 +106,6 @@ bool check_sve2_valid() { } #endif -// http://elixir.free-electrons.com/linux/latest/source/arch/arm64/include/uapi/asm/hwcap.h -#if defined(LITE_WITH_ANDROID) && defined(__aarch64__) -#include /* Get HWCAP bits from asm/hwcap.h */ -#include -#define AARCH64_HWCAP_SVE (1UL << 22) -#define AARCH64_HWCAP2_SVE2 (1UL << 1) -#define AARCH64_HWCAP2_SVEAES (1UL << 2) -#define AARCH64_HWCAP2_SVEPMULL (1UL << 3) -#define AARCH64_HWCAP2_SVEBITPERM (1UL << 4) -#define AARCH64_HWCAP2_SVESHA3 (1UL << 5) -#define AARCH64_HWCAP2_SVESM4 (1UL << 6) -#define AARCH64_HWCAP2_SVEI8MM (1UL << 9) -#define AARCH64_HWCAP2_SVEF32MM (1UL << 10) -#define AARCH64_HWCAP2_SVEF64MM (1UL << 11) -#define AARCH64_HWCAP2_SVEBF16 (1UL << 12) -#define AARCH64_HWCAP2_I8MM (1UL << 13) -#define AARCH64_HWCAP2_BF16 (1UL << 14) -#define AT_HWCAP 16 -#define AT_HWCAP2 26 - -bool check_sve2_valid() { - auto mask = static_cast(getauxval(AT_HWCAP2)); // Android API >= 18 - if (mask & AARCH64_HWCAP2_SVE2) return true; - return false; -} -#endif - #if ((defined LITE_WITH_ARM) || (defined LITE_WITH_MLU)) LITE_THREAD_LOCAL lite_api::PowerMode DeviceInfo::mode_; LITE_THREAD_LOCAL ARMArch DeviceInfo::arch_; From e50b3cac7963fbafb67c2fb6e2690c9a7bcd34c2 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Tue, 24 May 2022 15:43:46 +0800 Subject: [PATCH 08/11] fix error --- lite/tests/kernels/softmax_compute_test.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lite/tests/kernels/softmax_compute_test.cc b/lite/tests/kernels/softmax_compute_test.cc index 4a225281c72..e39ed2ce0db 100644 --- a/lite/tests/kernels/softmax_compute_test.cc +++ b/lite/tests/kernels/softmax_compute_test.cc @@ -100,9 +100,8 @@ TEST(Softmax, precision) { LOG(INFO) << "test softmax op"; float abs_error = 4e-5; Place place; - = #if defined(LITE_WITH_NNADAPTER) - place = TARGET(kNNAdapter); + place = TARGET(kNNAdapter); #if defined(NNADAPTER_WITH_HUAWEI_ASCEND_NPU) abs_error = 1e-2; #elif defined(NNADAPTER_WITH_CAMBRICON_MLU) @@ -115,18 +114,18 @@ TEST(Softmax, precision) { return; #endif #elif defined(LITE_WITH_NPU) - place = TARGET(kNPU); + place = TARGET(kNPU); abs_error = 4e-3; // Using fp16 in NPU // #elif defined(LITE_WITH_OPENCL) // place = Place(TARGET(kOpenCL), PRECISION(kFP16), // DATALAYOUT(kImageDefault)); // abs_error = 1e-2; // Using fp16 in OPENCL #elif defined(LITE_WITH_XPU) - place = TARGET(kXPU); + place = TARGET(kXPU); #elif defined(LITE_WITH_ARM) - place = TARGET(kARM); + place = TARGET(kARM); #else - return; + return; #endif for (auto x_dims : From b0db45f92644ec01f388b76b7d6e2991c72a8506 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Tue, 24 May 2022 16:17:16 +0800 Subject: [PATCH 09/11] Update softmax_compute.cc --- lite/kernels/arm/softmax_compute.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lite/kernels/arm/softmax_compute.cc b/lite/kernels/arm/softmax_compute.cc index 55bf1adeedc..bf20241aa99 100644 --- a/lite/kernels/arm/softmax_compute.cc +++ b/lite/kernels/arm/softmax_compute.cc @@ -52,8 +52,8 @@ void SoftmaxCompute::Run() { lite::arm::math::sve::softmax_basic_sve( din, dout, axis_size, inner_num, outer_num); } + return; } - return; #endif if (inner_num == 1) { From dcc28265b650ff5f41ac8bd52129d13a6576abe4 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Tue, 24 May 2022 20:13:38 +0800 Subject: [PATCH 10/11] Update softmax_compute.cc --- lite/kernels/arm/softmax_compute.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lite/kernels/arm/softmax_compute.cc b/lite/kernels/arm/softmax_compute.cc index bf20241aa99..e2bcc1ef0ef 100644 --- a/lite/kernels/arm/softmax_compute.cc +++ b/lite/kernels/arm/softmax_compute.cc @@ -114,8 +114,8 @@ void SoftmaxCompute::Run() { lite::arm::math::sve::softmax_basic_sve( din, dout, axis_size, inner_num, outer_num); } + return; } - return; #endif if (inner_num == 1) { From 9490bdf068eb10d7324a627ef8ea2cd73f242a0a Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Wed, 25 May 2022 11:17:11 +0800 Subject: [PATCH 11/11] fix ut error --- lite/kernels/arm/softmax_compute_test.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lite/kernels/arm/softmax_compute_test.cc b/lite/kernels/arm/softmax_compute_test.cc index 63c8dbb5dd3..48e8fe02ed3 100644 --- a/lite/kernels/arm/softmax_compute_test.cc +++ b/lite/kernels/arm/softmax_compute_test.cc @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include "lite/core/op_registry.h" @@ -108,6 +110,9 @@ TEST(softmax_arm, compute) { param.x = &x; param.axis = axis; param.output = &output; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + softmax.SetContext(std::move(ctx)); softmax.SetParam(param); softmax.Run(); param.output = &output_ref;