diff --git a/.github/workflows/linux-x64-cpu-gcc-sde.yml b/.github/workflows/linux-x64-cpu-gcc-sde.yml deleted file mode 100644 index eb680173743b..000000000000 --- a/.github/workflows/linux-x64-cpu-gcc-sde.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: linux-x64-cpu-gcc-sde -on: - push: - branches: [master] - paths: - - '.github/workflows/linux-x64-cpu-gcc-sde.yml' - - 'CMakeLists.txt' - - 'cmake/**' - - 'src/*' - - 'src/layer/*' - - 'src/layer/x86/**' - - 'tests/**' - - 'tools/**' - - '!tools/pnnx/**' - - 'examples/**' - pull_request: - branches: [master] - paths: - - '.github/workflows/linux-x64-cpu-gcc-sde.yml' - - 'CMakeLists.txt' - - 'cmake/**' - - 'src/*' - - 'src/layer/*' - - 'src/layer/x86/**' - - 'tests/**' - - 'tools/**' - - '!tools/pnnx/**' - - 'examples/**' -concurrency: - group: linux-x64-cpu-gcc-sde-${{ github.ref }} - cancel-in-progress: true -permissions: - contents: read - -jobs: - linux-gcc-sde: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v4 - - name: update - run: sudo apt-get update - - name: gcc12 - run: sudo apt-get install gcc-12 g++-12 - - name: Setup SDE binaries - uses: petarpetrovt/setup-sde@v2.4 - - name: build-avx512-spr - env: - CC: gcc-12 - CXX: g++-12 - run: | - mkdir build-avx512-spr && cd build-avx512-spr - cmake -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j $(nproc) - - name: test-avx512-spr - run: | - cd build-avx512-spr - TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-spr;--" ctest --output-on-failure -j $(nproc) diff --git a/.github/workflows/linux-x64-sde.yml b/.github/workflows/linux-x64-sde.yml new file mode 100644 index 000000000000..1b3b43d0d446 --- /dev/null +++ b/.github/workflows/linux-x64-sde.yml @@ -0,0 +1,85 @@ +name: linux-x64-sde +on: + push: + branches: [master] + paths: + - '.github/workflows/linux-x64-sde.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'src/*' + - 'src/layer/*' + - 'src/layer/x86/**' + - 'tests/**' + - 'tools/**' + - '!tools/pnnx/**' + - 'examples/**' + pull_request: + branches: [master] + paths: + - '.github/workflows/linux-x64-sde.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'src/*' + - 'src/layer/*' + - 'src/layer/x86/**' + - 'tests/**' + - 'tools/**' + - '!tools/pnnx/**' + - 'examples/**' +concurrency: + group: linux-x64-sde-${{ github.ref }} + cancel-in-progress: true +permissions: + contents: read + +jobs: + gcc-sde: + runs-on: ubuntu-24.04 + steps: + - uses: actions/checkout@v4 + - name: update + run: sudo apt-get update + - name: gcc14 + run: sudo apt-get install gcc-14 g++-14 + - name: Setup SDE binaries + uses: petarpetrovt/setup-sde@v2.4 + - name: build + env: + CC: gcc-14 + CXX: g++-14 + run: | + mkdir build && cd build + cmake -DNCNN_BUILD_TESTS=ON .. + cmake --build . 
-j $(nproc) + - name: test-p4p + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-p4p;--" ctest --output-on-failure -j $(nproc) + - name: test-snb + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-snb;--" ctest --output-on-failure -j $(nproc) + - name: test-hsw + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-hsw;--" ctest --output-on-failure -j $(nproc) + - name: test-adl + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-adl;--" ctest --output-on-failure -j $(nproc) + - name: test-arl + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-arl;--" ctest --output-on-failure -j $(nproc) + - name: test-skx + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-skx;--" ctest --output-on-failure -j $(nproc) + - name: test-spr + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-spr;--" ctest --output-on-failure -j $(nproc) + - name: test-gnr + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-gnr;--" ctest --output-on-failure -j $(nproc) diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index e846f2fd8460..eef0164e30c4 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -38,7 +38,7 @@ jobs: LD_LIBRARY_PATH: /data/action/install/lib64 run: | mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=debug -DNCNN_VULKAN=ON -DNCNN_COVERAGE=ON -DNCNN_RUNTIME_CPU=OFF -DNCNN_AVX2=ON -DNCNN_XOP=OFF -DNCNN_AVXVNNI=OFF -DNCNN_AVX512=ON -DNCNN_AVX512VNNI=ON -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake -DCMAKE_BUILD_TYPE=debug -DNCNN_VULKAN=ON -DNCNN_COVERAGE=ON -DNCNN_RUNTIME_CPU=OFF -DNCNN_AVX2=ON -DNCNN_XOP=OFF -DNCNN_AVXVNNI=OFF -DNCNN_AVXNECONVERT=OFF -DNCNN_AVX512=OFF -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. cmake --build . 
-j 4 - name: test env: @@ -54,61 +54,72 @@ jobs: lcov --list lcov.info - name: codecov - id: codecov - continue-on-error: true uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true plugins: noop files: build/lcov.info - - name: set the status - if: always() - run: | - if ${{ steps.codecov.outcome=='success' }}; then - echo fine - else - exit 1 - fi - linux-gcc-x64-avx512-spr: - runs-on: ubuntu-22.04 + linux-gcc-x64-sde: + name: linux-gcc-sde-${{ matrix.cpu }} + runs-on: ubuntu-24.04 + strategy: + fail-fast: false + matrix: + include: + - { cpu: hsw, AVX2: ON, AVXVNNI: OFF, AVXVNNIINT8: OFF, AVXNECONVERT: OFF, AVX512: OFF, AVX512VNNI: OFF, AVX512BF16: OFF, AVX512FP16: OFF } + - { cpu: adl, AVX2: ON, AVXVNNI: ON, AVXVNNIINT8: OFF, AVXNECONVERT: OFF, AVX512: OFF, AVX512VNNI: OFF, AVX512BF16: OFF, AVX512FP16: OFF } + - { cpu: arl, AVX2: ON, AVXVNNI: ON, AVXVNNIINT8: ON, AVXNECONVERT: ON, AVX512: OFF, AVX512VNNI: OFF, AVX512BF16: OFF, AVX512FP16: OFF } + - { cpu: spr, AVX2: ON, AVXVNNI: OFF, AVXVNNIINT8: OFF, AVXNECONVERT: OFF, AVX512: ON, AVX512VNNI: ON, AVX512BF16: ON, AVX512FP16: ON } steps: - uses: actions/checkout@v4 - name: update run: sudo apt-get update - - name: gcc12 - run: sudo apt-get install gcc-12 g++-12 + - name: gcc14 + run: sudo apt-get install gcc-14 g++-14 - name: lcov run: sudo apt-get install lcov - name: Setup SDE binaries uses: petarpetrovt/setup-sde@v2.4 - - name: build-avx512-spr + - name: build env: - CC: gcc-12 - CXX: g++-12 + CC: gcc-14 + CXX: g++-14 run: | - mkdir build-avx512-spr && cd build-avx512-spr - cmake -DCMAKE_BUILD_TYPE=debug -DNCNN_COVERAGE=ON -DNCNN_RUNTIME_CPU=OFF -DNCNN_AVX2=ON -DNCNN_AVX512=ON -DNCNN_AVX512VNNI=ON -DNCNN_AVX512BF16=ON -DNCNN_AVX512FP16=ON -DNCNN_XOP=OFF -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j 2 - - name: test-avx512-spr + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=debug -DNCNN_COVERAGE=ON -DNCNN_RUNTIME_CPU=OFF \ + -DNCNN_AVX=ON \ + -DNCNN_F16C=ON \ + -DNCNN_XOP=OFF \ + -DNCNN_AVX2=${{ matrix.AVX2 }} \ + -DNCNN_AVXVNNI=${{ matrix.AVXVNNI }} \ + -DNCNN_AVXVNNIINT8=${{ matrix.AVXVNNIINT8 }} \ + -DNCNN_AVXNECONVERT=${{ matrix.AVXNECONVERT }} \ + -DNCNN_AVX512=${{ matrix.AVX512 }} \ + -DNCNN_AVX512VNNI=${{ matrix.AVX512VNNI }} \ + -DNCNN_AVX512BF16=${{ matrix.AVX512BF16 }} \ + -DNCNN_AVX512FP16=${{ matrix.AVX512FP16 }} \ + -DNCNN_OPENMP=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . 
-j $(nproc) + - name: test + run: | + cd build + TESTS_EXECUTABLE_LOADER=$SDE_PATH/sde64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-${{ matrix.cpu }};--" ctest --output-on-failure -j $(nproc) - name: lcov-collect run: | - cd build-avx512-spr - lcov --gcov-tool gcov-12 -d ./src -c -o lcov.info + cd build + lcov --gcov-tool gcov-14 -d ./src -c -o lcov.info lcov -r lcov.info '/usr/*' -o lcov.info - lcov -r lcov.info '*/build-avx512-spr/*' -o lcov.info + lcov -r lcov.info '*/build/*' -o lcov.info lcov --list lcov.info - - name: codecov-avx512-spr + - name: codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} disable_search: true plugins: noop - files: build-avx512-spr/lcov.info + files: build/lcov.info linux-gcc-riscv64-rvv: strategy: diff --git a/CMakeLists.txt b/CMakeLists.txt index 473440cc454a..5851552b2a5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -508,7 +508,7 @@ else() check_cxx_compiler_flag("/arch:AVX512" NCNN_COMPILER_SUPPORT_X86_AVX512) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") - check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) @@ -545,7 +545,7 @@ else() check_cxx_compiler_flag("/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni") - check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp index 0ebe5974d0b7..b4a263f9b53e 100644 --- a/src/layer/gemm.cpp +++ b/src/layer/gemm.cpp @@ -220,9 +220,7 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A const int N = BT_int8.h; const int K = A_int8.w; // assert A_int8.w == BT_int8.w - // NCNN_LOGE("naive ds %f %f", A_int8_scales[0], BT_int8_scale); - - // #pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < M; i++) { const int out_hstep = top_blob.dims == 3 ? 
(int)top_blob.cstep : top_blob.w; @@ -232,8 +230,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A const float descale = 1.f / (A_int8_scales[i] * BT_int8_scale); - // NCNN_LOGE("descale %f", descale); - for (int j = 0; j < N; j++) { const signed char* ptrBT = BT_int8.row(j); @@ -241,7 +237,6 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A int sum = 0; for (int k = 0; k < K; k++) { - // NCNN_LOGE("ptrA[%d] %d", k, ptrA[k]); sum += ptrA[k] * ptrBT[k]; } @@ -501,8 +496,6 @@ int Gemm::forward_int8(const std::vector& bottom_blobs, std::vector& t absmax = std::max(absmax, (float)fabs(ptr[k])); } - // NCNN_LOGE("A[%d] absmax %f", i, absmax); - float A_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax; A_int8_scales[i] = A_int8_scale; @@ -534,8 +527,6 @@ int Gemm::forward_int8(const std::vector& bottom_blobs, std::vector& t } } - // NCNN_LOGE("B0 absmax %f", absmax); - B_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax; for (int i = 0; i < B0_int8.h; i++) diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h index acdb7cf83db2..82754e08b880 100644 --- a/src/layer/x86/convolution_3x3_winograd_int8.h +++ b/src/layer/x86/convolution_3x3_winograd_int8.h @@ -17,7 +17,7 @@ int conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, int conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVX512VNNI__ int conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); int conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif @@ -705,41 +705,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); - _sum8 = _mm512_dpwssd_epi32(_sum8, _pA2, _pB0); - _sum9 = _mm512_dpwssd_epi32(_sum9, _pA2, _pB1); - _suma = _mm512_dpwssd_epi32(_suma, _pA2, _pB2); - _sumb = _mm512_dpwssd_epi32(_sumb, _pA2, _pB3); - _sumc = _mm512_dpwssd_epi32(_sumc, _pA3, _pB0); - _sumd = _mm512_dpwssd_epi32(_sumd, _pA3, _pB1); - _sume = _mm512_dpwssd_epi32(_sume, _pA3, _pB2); - _sumf = _mm512_dpwssd_epi32(_sumf, _pA3, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); - _sum8 = _mm512_add_epi32(_sum8, 
_mm512_madd_epi16(_pA2, _pB0)); - _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); - _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA2, _pB2)); - _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA2, _pB3)); - _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA3, _pB0)); - _sumd = _mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA3, _pB1)); - _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); - _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_comp_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_comp_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_comp_dpwssd_epi32(_suma, _pA2, _pB2); + _sumb = _mm512_comp_dpwssd_epi32(_sumb, _pA2, _pB3); + _sumc = _mm512_comp_dpwssd_epi32(_sumc, _pA3, _pB0); + _sumd = _mm512_comp_dpwssd_epi32(_sumd, _pA3, _pB1); + _sume = _mm512_comp_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_comp_dpwssd_epi32(_sumf, _pA3, _pB3); pA += 32; pB += 32; @@ -984,25 +965,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); pA += 32; pB += 16; @@ -1150,17 +1120,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, 
_pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 32; pB += 8; @@ -1244,13 +1207,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 32; pB += 4; @@ -1312,11 +1270,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA, _pB); pA += 32; pB += 2; @@ -1396,25 +1350,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA00, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA00, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA11, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA11, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA11, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA11, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA00, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA00, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA11, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA11, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA00, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA00, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA11, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA11, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA11, _pB3); pA += 16; pB += 32; @@ -1602,42 +1545,24 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, 
_pB23); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB23); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA11, _pB01); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA11, _pB23); #else // __AVX512F__ __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); -#if __AVXVNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm256_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm256_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm256_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm256_dpwssd_epi32(_sum7, _pA1, _pB3); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3)); - _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA1, _pB0)); - _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA1, _pB1)); - _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); - _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); -#endif // __AVXVNNI__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm256_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm256_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm256_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_comp_dpwssd_epi32(_sum7, _pA1, _pB3); #endif // __AVX512F__ pA += 16; @@ -1854,17 +1779,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 16; pB += 8; @@ -1948,13 +1866,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pB0 = 
_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 16; pB += 4; @@ -2081,17 +1994,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 32; @@ -2231,18 +2137,11 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); -#endif -#else // __AVX2__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); +#else // __AVX2__ __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); @@ -2250,25 +2149,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA0, _pB2, _sum2); - _sum3 = _mm_maddd_epi16(_pA0, _pB3, _sum3); - _sum4 = _mm_maddd_epi16(_pA1, _pB0, _sum4); - _sum5 = _mm_maddd_epi16(_pA1, _pB1, _sum5); - _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); - _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA0, _pB2)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA0, _pB3)); - _sum4 = 
_mm_add_epi32(_sum4, _mm_madd_epi16(_pA1, _pB0)); - _sum5 = _mm_add_epi32(_sum5, _mm_madd_epi16(_pA1, _pB1)); - _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); - _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm_comp_dpwssd_epi32(_sum7, _pA1, _pB3); #endif // __AVX2__ pA += 8; @@ -2473,17 +2361,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 8; @@ -2582,13 +2463,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 8; pB += 4; @@ -2660,11 +2536,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pA = _mm_loadu_si128((const __m128i*)pA); __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB); pA += 8; pB += 2; @@ -2729,13 +2601,9 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); __m512i _pA1 = _mm512_set1_epi32(((const int*)pA)[1]); __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA1, _pB0); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA1, _pB0)); -#endif // __AVX512VNNI__ + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA1, _pB0); pA += 4; pB += 32; @@ -3091,11 +2959,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); -#if __AVX512VNNI__ - _sum0 = 
_mm512_dpwssd_epi32(_sum0, _pA0, _pB0); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 2; pB += 32; @@ -4433,7 +4297,7 @@ static int conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, cons } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { return conv3x3s1_winograd23_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); @@ -6265,7 +6129,7 @@ static int conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, cons } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { return conv3x3s1_winograd43_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 6c06a7edce6c..c0da248bd213 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -16,7 +16,7 @@ int convolution_im2col_gemm_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVX512VNNI__ int convolution_im2col_gemm_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif @@ -738,41 +738,22 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); - _sum8 = _mm512_dpwssd_epi32(_sum8, _pA2, _pB0); - _sum9 = _mm512_dpwssd_epi32(_sum9, _pA2, _pB1); - _suma = _mm512_dpwssd_epi32(_suma, _pA2, _pB2); - _sumb = _mm512_dpwssd_epi32(_sumb, _pA2, _pB3); - _sumc = _mm512_dpwssd_epi32(_sumc, _pA3, _pB0); - _sumd = _mm512_dpwssd_epi32(_sumd, _pA3, _pB1); - _sume = _mm512_dpwssd_epi32(_sume, _pA3, _pB2); - _sumf = _mm512_dpwssd_epi32(_sumf, _pA3, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); - _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); - _sum9 = _mm512_add_epi32(_sum9, 
_mm512_madd_epi16(_pA2, _pB1)); - _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA2, _pB2)); - _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA2, _pB3)); - _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA3, _pB0)); - _sumd = _mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA3, _pB1)); - _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); - _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_comp_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_comp_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_comp_dpwssd_epi32(_suma, _pA2, _pB2); + _sumb = _mm512_comp_dpwssd_epi32(_sumb, _pA2, _pB3); + _sumc = _mm512_comp_dpwssd_epi32(_sumc, _pA3, _pB0); + _sumd = _mm512_comp_dpwssd_epi32(_sumd, _pA3, _pB1); + _sume = _mm512_comp_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_comp_dpwssd_epi32(_sumf, _pA3, _pB3); pA += 32; pB += 32; @@ -1480,25 +1461,14 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); pA += 32; pB += 16; @@ -1914,17 +1884,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 1230 1230 1230 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = 
_mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 32; pB += 8; @@ -2164,13 +2127,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1010 1010 1010 1010 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 32; pB += 4; @@ -2286,11 +2244,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 0xxx0xxx0xxx0xxx -> 00000000... __m512i _pB0 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_AAAA); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 32; pB += 2; @@ -2417,25 +2371,14 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA00, _pB2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA00, _pB3); - _sum4 = _mm512_dpwssd_epi32(_sum4, _pA11, _pB0); - _sum5 = _mm512_dpwssd_epi32(_sum5, _pA11, _pB1); - _sum6 = _mm512_dpwssd_epi32(_sum6, _pA11, _pB2); - _sum7 = _mm512_dpwssd_epi32(_sum7, _pA11, _pB3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA00, _pB2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA00, _pB3)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA11, _pB0)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA11, _pB1)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA00, _pB2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA00, _pB3); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA11, _pB0); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA11, _pB1); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA11, _pB3); pA += 16; pB += 32; @@ -2802,17 +2745,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23); -#else - _sum0 = 
_mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB23); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA11, _pB01); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA11, _pB23); #else // __AVX512F__ // 0123 4567 @@ -2827,25 +2763,14 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); -#if __AVXVNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3); - _sum4 = _mm256_dpwssd_epi32(_sum4, _pA1, _pB0); - _sum5 = _mm256_dpwssd_epi32(_sum5, _pA1, _pB1); - _sum6 = _mm256_dpwssd_epi32(_sum6, _pA1, _pB2); - _sum7 = _mm256_dpwssd_epi32(_sum7, _pA1, _pB3); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3)); - _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA1, _pB0)); - _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA1, _pB1)); - _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); - _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); -#endif // __AVXVNNI__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm256_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm256_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm256_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_comp_dpwssd_epi32(_sum7, _pA1, _pB3); #endif // __AVX512F__ pA += 16; @@ -3315,17 +3240,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 1230 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 16; pB += 8; @@ -3517,13 +3435,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1010 1010 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, 
_mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 16; pB += 4; @@ -3653,11 +3566,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 0xxx0xxx -> 00000000 11111111 __m256i _pB0 = _mm256_shuffle_epi32(_pBB, _MM_SHUFFLE(0, 0, 0, 0)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 16; pB += 2; @@ -3773,17 +3682,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 5674 9ab8 defc __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 32; @@ -3983,17 +3885,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 5674 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); - _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); #else // __AVX2__ #if __SSE4_1__ _pA = _mm_cvtepi8_epi16(_pA); @@ -4018,25 +3913,14 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _pB2 = _mm_shuffle_epi32(_pBl, _MM_SHUFFLE(0, 3, 2, 1)); __m128i _pB3 = _mm_shuffle_epi32(_pBh, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA0, _pB2, _sum2); - _sum3 = _mm_maddd_epi16(_pA0, _pB3, _sum3); - _sum4 = _mm_maddd_epi16(_pA1, _pB0, _sum4); - _sum5 = _mm_maddd_epi16(_pA1, _pB1, _sum5); - _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); - _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA0, _pB2)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA0, _pB3)); - _sum4 = _mm_add_epi32(_sum4, _mm_madd_epi16(_pA1, _pB0)); - _sum5 = _mm_add_epi32(_sum5, 
_mm_madd_epi16(_pA1, _pB1)); - _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); - _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm_comp_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm_comp_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm_comp_dpwssd_epi32(_sum7, _pA1, _pB3); #endif // __AVX2__ pA += 8; @@ -4381,17 +4265,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 8; @@ -4570,13 +4447,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(2, 3, 0, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 8; pB += 4; @@ -4707,11 +4579,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB); pA += 8; pB += 2; @@ -4821,13 +4689,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 5674 9ab8 defc __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 4; pB += 32; @@ -4942,14 +4805,9 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 1230 5674 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); -#if __AVX512VNNI__ || __AVXVNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); -#endif // __AVX512VNNI__ || __AVXVNNI__ -#else // __AVX2__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + 
_sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); +#else // __AVX2__ #if __SSE4_1__ _pA = _mm_cvtepi8_epi16(_pA); #else @@ -4968,17 +4826,10 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 0123 // 4567 -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); #endif // __AVX2__ pA += 4; @@ -5158,13 +5009,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 4; pB += 8; @@ -5387,11 +5233,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 2; pB += 32; @@ -5466,12 +5308,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); -#if __AVX512VNNI__ || __AVXVNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); -#endif // __AVX512VNNI__ || __AVXVNNI__ -#else // __AVX2__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); +#else // __AVX2__ #if __SSE4_1__ _pA = _mm_cvtepi8_epi16(_pA); #else @@ -5482,13 +5320,8 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m128i _pB0 = _mm_unpacklo_epi8(_pB, _extpB); __m128i _pB1 = _mm_unpackhi_epi8(_pB, _extpB); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); #endif // __AVX2__ pA += 2; @@ -5580,11 +5413,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // 0xxx -> 0000 __m128i _pA0 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(0, 0, 0, 0)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB); pA += 2; pB += 8; @@ -7561,7 +7390,7 @@ static int convolution_im2col_gemm_int8(const Mat& bottom_blob, Mat& top_blob, c } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && 
!__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { return convolution_im2col_gemm_int8_avxvnni(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); diff --git a/src/layer/x86/convolution_packed_int8.h b/src/layer/x86/convolution_packed_int8.h index 8a1659565f54..cc89769a2862 100644 --- a/src/layer/x86/convolution_packed_int8.h +++ b/src/layer/x86/convolution_packed_int8.h @@ -1056,73 +1056,38 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrrr30 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr3, 0)); __m512i _rrrr31 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr3, 1)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_AAAA), _w0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_AAAA), _w0); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_BBBB), _w1); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_BBBB), _w1); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_CCCC), _w2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_CCCC), _w2); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_DDDD), _w3); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_DDDD), _w3); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_AAAA), _w4); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_AAAA), _w4); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_BBBB), _w5); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_BBBB), _w5); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_CCCC), _w6); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_CCCC), _w6); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_DDDD), _w7); - 
_sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_DDDD), _w7); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr20, _MM_PERM_AAAA), _w0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr30, _MM_PERM_AAAA), _w0)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr20, _MM_PERM_BBBB), _w1)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr30, _MM_PERM_BBBB), _w1)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr20, _MM_PERM_CCCC), _w2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr30, _MM_PERM_CCCC), _w2)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr20, _MM_PERM_DDDD), _w3)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr30, _MM_PERM_DDDD), _w3)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr21, _MM_PERM_AAAA), _w4)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr31, _MM_PERM_AAAA), _w4)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr21, _MM_PERM_BBBB), _w5)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr31, _MM_PERM_BBBB), _w5)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr21, _MM_PERM_CCCC), _w6)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr31, _MM_PERM_CCCC), _w6)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr21, _MM_PERM_DDDD), _w7)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr31, _MM_PERM_DDDD), _w7)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, 
_MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_AAAA), _w0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_AAAA), _w0); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_BBBB), _w1); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_BBBB), _w1); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_CCCC), _w2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_CCCC), _w2); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr20, _MM_PERM_DDDD), _w3); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr30, _MM_PERM_DDDD), _w3); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_AAAA), _w4); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_AAAA), _w4); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_BBBB), _w5); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_BBBB), _w5); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_CCCC), _w6); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_CCCC), _w6); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr21, _MM_PERM_DDDD), _w7); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr31, _MM_PERM_DDDD), _w7); kptr += 256; } @@ -1180,41 +1145,22 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrrr2 = _mm512_broadcast_i32x4(_r2); __m512i _rrrr3 = _mm512_broadcast_i32x4(_r3); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_AAAA), _w0); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_AAAA), _w0); - _sum0 = 
_mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_BBBB), _w1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_BBBB), _w1); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_BBBB), _w1); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_CCCC), _w2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_CCCC), _w2); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_DDDD), _w3); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_DDDD), _w3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr2, _MM_PERM_AAAA), _w0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr3, _MM_PERM_AAAA), _w0)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_BBBB), _w1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr2, _MM_PERM_BBBB), _w1)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr3, _MM_PERM_BBBB), _w1)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr2, _MM_PERM_CCCC), _w2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr3, _MM_PERM_CCCC), _w2)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr2, _MM_PERM_DDDD), _w3)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr3, _MM_PERM_DDDD), _w3)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_AAAA), _w0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_AAAA), _w0); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_BBBB), _w1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_BBBB), _w1); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_BBBB), _w1); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, 
_mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_CCCC), _w2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_CCCC), _w2); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr2, _MM_PERM_DDDD), _w3); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr3, _MM_PERM_DDDD), _w3); kptr += 128; } @@ -1242,17 +1188,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm512_dpwssd_epi32(_sum1, _r1, _w); - _sum2 = _mm512_dpwssd_epi32(_sum2, _r2, _w); - _sum3 = _mm512_dpwssd_epi32(_sum3, _r3, _w); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_r0, _w)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_r1, _w)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_r2, _w)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_r3, _w)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _r1, _w); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _r2, _w); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _r3, _w); kptr += 32; } @@ -1413,41 +1352,22 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrrr10 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr1, 0)); __m512i _rrrr11 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr1, 1)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0)); - _sum2 = 
_mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_AAAA), _w0); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_BBBB), _w1); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_CCCC), _w2); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr10, _MM_PERM_DDDD), _w3); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_AAAA), _w4); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_BBBB), _w5); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_CCCC), _w6); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr11, _MM_PERM_DDDD), _w7); kptr += 256; } @@ -1491,25 +1411,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrrr0 = _mm512_broadcast_i32x4(_r0); __m512i _rrrr1 = _mm512_broadcast_i32x4(_r1); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr1, 
_MM_PERM_BBBB), _w1); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_BBBB), _w1)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_AAAA), _w0); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_BBBB), _w1); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_CCCC), _w2); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr1, _MM_PERM_DDDD), _w3); kptr += 128; } @@ -1531,13 +1440,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm512_dpwssd_epi32(_sum1, _r1, _w); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_r0, _w)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_r1, _w)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _r1, _w); kptr += 32; } @@ -1664,25 +1568,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrrr00 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr0, 0)); __m512i _rrrr01 = _mm512_broadcast_i32x4(_mm256_extracti128_si256(_rr0, 1)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), 
_w7); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3)); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_BBBB), _w1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_CCCC), _w2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr00, _MM_PERM_DDDD), _w3); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_AAAA), _w4); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_BBBB), _w5); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_CCCC), _w6); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr01, _MM_PERM_DDDD), _w7); kptr += 256; } @@ -1719,17 +1612,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const // 01234567 -> 01010101 01010101 01010101 01010101 __m512i _rrrr0 = _mm512_broadcast_i32x4(_r0); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); - _sum1 = _mm512_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); - _sum2 = _mm512_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); - _sum3 = _mm512_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_AAAA), _w0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_BBBB), _w1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_CCCC), _w2); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _mm512_shuffle_epi32(_rrrr0, _MM_PERM_DDDD), _w3); kptr += 128; } @@ -1748,11 +1634,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _val, _w); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_val, _w)); -#endif + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _val, _w); kptr += 32; } @@ -1941,41 +1823,22 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const 
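// ---------------------------------------------------------------------------
// Editor's sketch (assumption, not part of the patch): the _mm512_/_mm256_/_mm_
// comp_dpwssd_epi32 helpers that this diff substitutes for the removed per-ISA
// branches are presumably small static inline wrappers in a shared x86 helper
// header (e.g. x86_usability.h, which is not shown in this excerpt). Judging
// from the removed #if blocks, they fold the dispatch into one call: the VNNI
// dpwssd instruction where available, XOP maddd for the 128-bit case, and the
// generic madd_epi16 + add_epi32 fallback otherwise. A minimal sketch, with the
// same src-first argument order used by the call sites in this diff:
#include <immintrin.h>

#if __AVX512F__
static inline __m512i _mm512_comp_dpwssd_epi32(__m512i src, __m512i a, __m512i b)
{
#if __AVX512VNNI__
    return _mm512_dpwssd_epi32(src, a, b);
#else
    return _mm512_add_epi32(src, _mm512_madd_epi16(a, b));
#endif
}
#endif // __AVX512F__

#if __AVX2__
static inline __m256i _mm256_comp_dpwssd_epi32(__m256i src, __m256i a, __m256i b)
{
#if __AVX512VNNI__ || __AVXVNNI__
    return _mm256_dpwssd_epi32(src, a, b);
#else
    return _mm256_add_epi32(src, _mm256_madd_epi16(a, b));
#endif
}
#endif // __AVX2__

static inline __m128i _mm_comp_dpwssd_epi32(__m128i src, __m128i a, __m128i b)
{
#if __AVX512VNNI__ || __AVXVNNI__
    return _mm_dpwssd_epi32(src, a, b);
#elif __XOP__
    return _mm_maddd_epi16(a, b, src);
#else
    return _mm_add_epi32(src, _mm_madd_epi16(a, b));
#endif
}
// With such wrappers, each accumulate step in the kernels below collapses to a
// single _mm*_comp_dpwssd_epi32(sum, a, b) call, as the replaced hunks show.
// ---------------------------------------------------------------------------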
__m512i _rrr3l = _mm512_unpacklo_epi64(_rrr3, _rrr3); __m512i _rrr3h = _mm512_unpackhi_epi64(_rrr3, _rrr3); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum11 = _mm512_add_epi32(_sum11, 
_mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 128; } @@ -2063,41 +1926,22 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _rr2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r2), _r2, 1); __m256i _rr3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r3), _r3, 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum3 = 
_mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, 
_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 64; } @@ -2129,17 +1973,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _rr0, _w); - _sum1 = _mm256_dpwssd_epi32(_sum1, _rr1, _w); - _sum2 = _mm256_dpwssd_epi32(_sum2, _rr2, _w); - _sum3 = _mm256_dpwssd_epi32(_sum3, _rr3, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_rr0, _w)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_rr1, _w)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_rr2, _w)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_rr3, _w)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _rr0, _w); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _rr1, _w); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _rr2, _w); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _rr3, _w); kptr += 16; } @@ -2336,25 +2173,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrr1l = _mm512_unpacklo_epi64(_rrr1, _rrr1); __m512i _rrr1h = _mm512_unpackhi_epi64(_rrr1, _rrr1); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = 
_mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 128; } @@ -2422,25 +2248,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _rr0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _r0, 1); __m256i _rr1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r1), _r1, 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr1, 
_MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 64; } @@ -2464,13 +2279,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _rr0, _w); - _sum1 = _mm256_dpwssd_epi32(_sum1, _rr1, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_rr0, _w)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_rr1, _w)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _rr0, _w); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _rr1, _w); kptr += 16; } @@ -2614,17 +2424,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrr0l = _mm512_unpacklo_epi64(_rrr0, _rrr0); __m512i _rrr0h = _mm512_unpackhi_epi64(_rrr0, _rrr0); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(1, 1, 1, 1)), _w2); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 128; } @@ -2683,17 +2486,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const // 01234567 -> 01010101 01010101 __m256i _rr0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _r0, 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum1 = _mm256_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); - _sum2 = _mm256_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); - _sum3 = _mm256_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _mm256_shuffle_epi32(_rr0, 
_MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); kptr += 64; } @@ -2717,11 +2513,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _val, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_val, _w)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _val, _w); kptr += 16; } @@ -2915,25 +2707,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrr3l = _mm512_unpacklo_epi64(_rrr3, _rrr3); __m512i _rrr3h = _mm512_unpackhi_epi64(_rrr3, _rrr3); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr2l, _rrr2h, _MM_SHUFFLE(3, 1, 
3, 1)), _w1); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr3l, _rrr3h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); kptr += 64; } @@ -3037,26 +2818,15 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _rr2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r2), _mm_shuffle_epi32(_r2, _MM_SHUFFLE(2, 3, 0, 1)), 1); __m256i _rr3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r3), _mm_shuffle_epi32(_r3, _MM_SHUFFLE(2, 3, 0, 1)), 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00 = _mm256_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm256_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum22 = _mm256_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum33 = _mm256_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum00 = _mm256_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum11 = _mm256_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum22 = _mm256_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum33 = _mm256_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w1); -#else - _sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum22 = _mm256_add_epi32(_sum22, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum33 = _mm256_add_epi32(_sum33, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum22 = _mm256_add_epi32(_sum22, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum33 = _mm256_add_epi32(_sum33, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); -#endif -#else // __AVX2__ + _sum00 = _mm256_comp_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = _mm256_comp_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum22 = _mm256_comp_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum33 = _mm256_comp_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum00 = _mm256_comp_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum11 = _mm256_comp_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum22 = _mm256_comp_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr2, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum33 = _mm256_comp_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr3, _MM_SHUFFLE(2, 2, 2, 2)), _w1); +#else // __AVX2__ __m128i _w01 = _mm_load_si128((const __m128i*)kptr); __m128i _w23 = _mm_load_si128((const __m128i*)(kptr + 16)); __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); @@ -3067,41 +2837,22 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23); // 01234567 -> 01010101 -#if __XOP__ - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum0); - 
_sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum3); - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum3); - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum3); - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r2, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r3, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __XOP__ + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r2, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r3, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum2 = 
_mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r2, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r3, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r2, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r3, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r2, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r3, _MM_SHUFFLE(3, 3, 3, 3)), _w3); #endif // __AVX2__ kptr += 32; @@ -3146,22 +2897,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm_dpwssd_epi32(_sum1, _r1, _w); - _sum2 = _mm_dpwssd_epi32(_sum2, _r2, _w); - _sum3 = _mm_dpwssd_epi32(_sum3, _r3, _w); -#elif __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w, _sum0); - _sum1 = _mm_maddd_epi16(_r1, _w, _sum1); - _sum2 = _mm_maddd_epi16(_r2, _w, _sum2); - _sum3 = _mm_maddd_epi16(_r3, _w, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r1, _w)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_r2, _w)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_r3, _w)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r1, _w); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _r2, _w); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _r3, _w); kptr += 8; } @@ -3356,17 +3095,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrr1l = _mm512_unpacklo_epi64(_rrr1, _rrr1); __m512i _rrr1h = _mm512_unpackhi_epi64(_rrr1, _rrr1); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum22 = _mm512_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); - _sum33 = _mm512_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); + _sum33 = 
_mm512_comp_dpwssd_epi32(_sum33, _mm512_shuffle_i32x4(_rrr1l, _rrr1h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); kptr += 64; } @@ -3448,18 +3180,11 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _rr0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 3, 0, 1)), 1); __m256i _rr1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r1), _mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 3, 0, 1)), 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00 = _mm256_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm256_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum22 = _mm256_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); - _sum33 = _mm256_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1); -#else - _sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum22 = _mm256_add_epi32(_sum22, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); - _sum33 = _mm256_add_epi32(_sum33, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); -#endif -#else // __AVX2__ + _sum00 = _mm256_comp_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = _mm256_comp_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum22 = _mm256_comp_dpwssd_epi32(_sum22, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); + _sum33 = _mm256_comp_dpwssd_epi32(_sum33, _mm256_shuffle_epi32(_rr1, _MM_SHUFFLE(2, 2, 2, 2)), _w1); +#else // __AVX2__ __m128i _w01 = _mm_load_si128((const __m128i*)kptr); __m128i _w23 = _mm_load_si128((const __m128i*)(kptr + 16)); __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); @@ -3470,25 +3195,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23); // 01234567 -> 01010101 -#if __XOP__ - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum3); - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum2 = _mm_add_epi32(_sum2, 
_mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __XOP__ + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r1, _MM_SHUFFLE(3, 3, 3, 3)), _w3); #endif // __AVX2__ kptr += 32; @@ -3527,16 +3241,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm_dpwssd_epi32(_sum1, _r1, _w); -#elif __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w, _sum0); - _sum1 = _mm_maddd_epi16(_r1, _w, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r1, _w)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r1, _w); kptr += 8; } @@ -3681,13 +3387,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _rrr0l = _mm512_unpacklo_epi64(_rrr0, _rrr0); __m512i _rrr0h = _mm512_unpackhi_epi64(_rrr0, _rrr0); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); - _sum11 = _mm512_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(2, 0, 2, 0)), _w0); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _mm512_shuffle_i32x4(_rrr0l, _rrr0h, _MM_SHUFFLE(3, 1, 3, 1)), _w1); kptr += 64; } @@ -3752,14 +3453,9 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const // 01234567 -> 01010101 23232323 __m256i _rr0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 3, 0, 1)), 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00 = _mm256_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); - _sum11 = _mm256_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); -#else - _sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1)); -#endif -#else // __AVX2__ + _sum00 = _mm256_comp_dpwssd_epi32(_sum00, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum11 = _mm256_comp_dpwssd_epi32(_sum11, _mm256_shuffle_epi32(_rr0, _MM_SHUFFLE(2, 2, 2, 2)), _w1); +#else // 
__AVX2__ __m128i _w01 = _mm_load_si128((const __m128i*)kptr); __m128i _w23 = _mm_load_si128((const __m128i*)(kptr + 16)); __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); @@ -3770,17 +3466,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m128i _w3 = _mm_unpackhi_epi8(_w23, _extw23); // 01234567 -> 01010101 -#if __XOP__ - _sum0 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0, _sum0); - _sum1 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1, _sum1); - _sum2 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2, _sum2); - _sum3 = _mm_maddd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3)); -#endif // __XOP__ + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(0, 0, 0, 0)), _w0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(1, 1, 1, 1)), _w1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(2, 2, 2, 2)), _w2); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _mm_shuffle_epi32(_r0, _MM_SHUFFLE(3, 3, 3, 3)), _w3); #endif // __AVX2__ kptr += 32; @@ -3815,13 +3504,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm_dpwssd_epi32(_sum0, _r0, _w); -#elif __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w); kptr += 8; } @@ -3996,17 +3679,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum00 = _mm512_dpwssd_epi32(_sum00, _valval0, _w); - _sum11 = _mm512_dpwssd_epi32(_sum11, _valval1, _w); - _sum22 = _mm512_dpwssd_epi32(_sum22, _valval2, _w); - _sum33 = _mm512_dpwssd_epi32(_sum33, _valval3, _w); -#else - _sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_valval0, _w)); - _sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_valval1, _w)); - _sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_valval2, _w)); - _sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_valval3, _w)); -#endif // __AVX512VNNI__ + _sum00 = _mm512_comp_dpwssd_epi32(_sum00, _valval0, _w); + _sum11 = _mm512_comp_dpwssd_epi32(_sum11, _valval1, _w); + _sum22 = _mm512_comp_dpwssd_epi32(_sum22, _valval2, _w); + _sum33 = _mm512_comp_dpwssd_epi32(_sum33, _valval3, _w); kptr += 32; } @@ -4115,41 +3791,23 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_w01); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00 = _mm256_dpwssd_epi32(_sum00, _valval0, _w); - _sum11 = _mm256_dpwssd_epi32(_sum11, _valval1, _w); - _sum22 = _mm256_dpwssd_epi32(_sum22, _valval2, _w); - _sum33 = _mm256_dpwssd_epi32(_sum33, _valval3, _w); -#else - _sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_valval0, _w)); - _sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_valval1, _w)); - _sum22 = 
_mm256_add_epi32(_sum22, _mm256_madd_epi16(_valval2, _w)); - _sum33 = _mm256_add_epi32(_sum33, _mm256_madd_epi16(_valval3, _w)); -#endif -#else // __AVX2__ + _sum00 = _mm256_comp_dpwssd_epi32(_sum00, _valval0, _w); + _sum11 = _mm256_comp_dpwssd_epi32(_sum11, _valval1, _w); + _sum22 = _mm256_comp_dpwssd_epi32(_sum22, _valval2, _w); + _sum33 = _mm256_comp_dpwssd_epi32(_sum33, _valval3, _w); +#else // __AVX2__ __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); -#if __XOP__ - _sum00 = _mm_maddd_epi16(_r0, _w0, _sum00); - _sum10 = _mm_maddd_epi16(_r0, _w1, _sum10); - _sum01 = _mm_maddd_epi16(_r1, _w0, _sum01); - _sum11 = _mm_maddd_epi16(_r1, _w1, _sum11); - _sum02 = _mm_maddd_epi16(_r2, _w0, _sum02); - _sum12 = _mm_maddd_epi16(_r2, _w1, _sum12); - _sum03 = _mm_maddd_epi16(_r3, _w0, _sum03); - _sum13 = _mm_maddd_epi16(_r3, _w1, _sum13); -#else - _sum00 = _mm_add_epi32(_sum00, _mm_madd_epi16(_r0, _w0)); - _sum10 = _mm_add_epi32(_sum10, _mm_madd_epi16(_r0, _w1)); - _sum01 = _mm_add_epi32(_sum01, _mm_madd_epi16(_r1, _w0)); - _sum11 = _mm_add_epi32(_sum11, _mm_madd_epi16(_r1, _w1)); - _sum02 = _mm_add_epi32(_sum02, _mm_madd_epi16(_r2, _w0)); - _sum12 = _mm_add_epi32(_sum12, _mm_madd_epi16(_r2, _w1)); - _sum03 = _mm_add_epi32(_sum03, _mm_madd_epi16(_r3, _w0)); - _sum13 = _mm_add_epi32(_sum13, _mm_madd_epi16(_r3, _w1)); -#endif // __XOP__ + _sum00 = _mm_comp_dpwssd_epi32(_sum00, _r0, _w0); + _sum10 = _mm_comp_dpwssd_epi32(_sum10, _r0, _w1); + _sum01 = _mm_comp_dpwssd_epi32(_sum01, _r1, _w0); + _sum11 = _mm_comp_dpwssd_epi32(_sum11, _r1, _w1); + _sum02 = _mm_comp_dpwssd_epi32(_sum02, _r2, _w0); + _sum12 = _mm_comp_dpwssd_epi32(_sum12, _r2, _w1); + _sum03 = _mm_comp_dpwssd_epi32(_sum03, _r3, _w0); + _sum13 = _mm_comp_dpwssd_epi32(_sum13, _r3, _w1); #endif // __AVX2__ kptr += 16; @@ -4209,13 +3867,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m128i _w0 = _mm_setr_epi16(kptr[0], kptr[2], kptr[0], kptr[2], kptr[0], kptr[2], kptr[0], kptr[2]); __m128i _w1 = _mm_setr_epi16(kptr[1], kptr[3], kptr[1], kptr[3], kptr[1], kptr[3], kptr[1], kptr[3]); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_r, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_r, _w1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r, _w0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r, _w1)); -#endif // __XOP__ + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r, _w0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r, _w1); kptr += 4; } @@ -4328,13 +3981,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum0 = _mm512_dpwssd_epi32(_sum0, _valval0, _w); - _sum1 = _mm512_dpwssd_epi32(_sum1, _valval1, _w); -#else - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_valval0, _w)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_valval1, _w)); -#endif // __AVX512VNNI__ + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _valval0, _w); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _valval1, _w); kptr += 32; } @@ -4414,29 +4062,17 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _valval0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _r0, 1); __m256i _valval1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r1), _r1, 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _valval0, _w); - _sum1 = _mm256_dpwssd_epi32(_sum1, 
_valval1, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_valval0, _w)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_valval1, _w)); -#endif -#else // __AVX2__ + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _valval0, _w); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _valval1, _w); +#else // __AVX2__ __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); -#if __XOP__ - _sum00 = _mm_maddd_epi16(_r0, _w0, _sum00); - _sum10 = _mm_maddd_epi16(_r0, _w1, _sum10); - _sum01 = _mm_maddd_epi16(_r1, _w0, _sum01); - _sum11 = _mm_maddd_epi16(_r1, _w1, _sum11); -#else - _sum00 = _mm_add_epi32(_sum00, _mm_madd_epi16(_r0, _w0)); - _sum10 = _mm_add_epi32(_sum10, _mm_madd_epi16(_r0, _w1)); - _sum01 = _mm_add_epi32(_sum01, _mm_madd_epi16(_r1, _w0)); - _sum11 = _mm_add_epi32(_sum11, _mm_madd_epi16(_r1, _w1)); -#endif // __XOP__ + _sum00 = _mm_comp_dpwssd_epi32(_sum00, _r0, _w0); + _sum10 = _mm_comp_dpwssd_epi32(_sum10, _r0, _w1); + _sum01 = _mm_comp_dpwssd_epi32(_sum01, _r1, _w0); + _sum11 = _mm_comp_dpwssd_epi32(_sum11, _r1, _w1); #endif // __AVX2__ kptr += 16; @@ -4563,11 +4199,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr)); -#if __AVX512VNNI__ - _sum01 = _mm512_dpwssd_epi32(_sum01, _valval, _w); -#else - _sum01 = _mm512_add_epi32(_sum01, _mm512_madd_epi16(_valval, _w)); -#endif + _sum01 = _mm512_comp_dpwssd_epi32(_sum01, _valval, _w); kptr += 32; } @@ -4629,23 +4261,14 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _rr0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_r0), _r0, 1); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum = _mm256_dpwssd_epi32(_sum, _rr0, _w); -#else - _sum = _mm256_add_epi32(_sum, _mm256_madd_epi16(_rr0, _w)); -#endif -#else // __AVX2__ + _sum = _mm256_comp_dpwssd_epi32(_sum, _rr0, _w); +#else // __AVX2__ __m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01); __m128i _w0 = _mm_unpacklo_epi8(_w01, _extw01); __m128i _w1 = _mm_unpackhi_epi8(_w01, _extw01); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_r0, _w1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r0, _w1)); -#endif // __XOP__ + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r0, _w1); #endif // __AVX2__ kptr += 16; @@ -4796,17 +4419,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _val0, _w); - _sum1 = _mm256_dpwssd_epi32(_sum1, _val1, _w); - _sum2 = _mm256_dpwssd_epi32(_sum2, _val2, _w); - _sum3 = _mm256_dpwssd_epi32(_sum3, _val3, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_val0, _w)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_val1, _w)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_val2, _w)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_val3, _w)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _val0, _w); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _val1, _w); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _val2, _w); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _val3, _w); kptr += 16; } @@ -4898,22 +4514,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& 
top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm_dpwssd_epi32(_sum1, _r1, _w); - _sum2 = _mm_dpwssd_epi32(_sum2, _r2, _w); - _sum3 = _mm_dpwssd_epi32(_sum3, _r3, _w); -#elif __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w, _sum0); - _sum1 = _mm_maddd_epi16(_r1, _w, _sum1); - _sum2 = _mm_maddd_epi16(_r2, _w, _sum2); - _sum3 = _mm_maddd_epi16(_r3, _w, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r1, _w)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_r2, _w)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_r3, _w)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r1, _w); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _r2, _w); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _r3, _w); kptr += 8; } @@ -4956,11 +4560,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m128i _r = _mm_setr_epi16(r0s[0], r0s[N], r1s[0], r1s[N], r2s[0], r2s[N], r3s[0], r3s[N]); __m128i _w = _mm_setr_epi16(kptr[0], kptr[1], kptr[0], kptr[1], kptr[0], kptr[1], kptr[0], kptr[1]); -#if __XOP__ - _sum = _mm_maddd_epi16(_r, _w, _sum); -#else - _sum = _mm_add_epi32(_sum, _mm_madd_epi16(_r, _w)); -#endif // __XOP__ + _sum = _mm_comp_dpwssd_epi32(_sum, _r, _w); kptr += 2; } @@ -5073,13 +4673,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm256_dpwssd_epi32(_sum0, _val0, _w); - _sum1 = _mm256_dpwssd_epi32(_sum1, _val1, _w); -#else - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_val0, _w)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_val1, _w)); -#endif + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _val0, _w); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _val1, _w); kptr += 16; } @@ -5149,16 +4744,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0 = _mm_dpwssd_epi32(_sum0, _r0, _w); - _sum1 = _mm_dpwssd_epi32(_sum1, _r1, _w); -#elif __XOP__ - _sum0 = _mm_maddd_epi16(_r0, _w, _sum0); - _sum1 = _mm_maddd_epi16(_r1, _w, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r0, _w)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r1, _w)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _r0, _w); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _r1, _w); kptr += 8; } @@ -5264,11 +4851,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const __m256i _w = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)kptr)); -#if __AVXVNNI__ || __AVX512VNNI__ - _sum = _mm256_dpwssd_epi32(_sum, _val, _w); -#else - _sum = _mm256_add_epi32(_sum, _mm256_madd_epi16(_val, _w)); -#endif + _sum = _mm256_comp_dpwssd_epi32(_sum, _val, _w); kptr += 16; } @@ -5324,13 +4907,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); #endif -#if __AVXVNNI__ || __AVX512VNNI__ - _sum = _mm_dpwssd_epi32(_sum, _r0, _w); -#elif __XOP__ - _sum = _mm_maddd_epi16(_r0, _w, _sum); -#else - _sum = _mm_add_epi32(_sum, _mm_madd_epi16(_r0, _w)); -#endif + _sum = _mm_comp_dpwssd_epi32(_sum, _r0, _w); kptr += 8; } diff --git a/src/layer/x86/gemm_int8.h 
b/src/layer/x86/gemm_int8.h new file mode 100644 index 000000000000..f9e0050fd553 --- /dev/null +++ b/src/layer/x86/gemm_int8.h @@ -0,0 +1,15646 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_avx512vnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void pack_A_tile_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_avxvnniint8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int 
j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose); +void gemm_transB_packed_tile_int8_avx2(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void gemm_transB_packed_tile_int8_xop(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_A_tile_int8_avx512vnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + pack_A_tile_int8_avxvnniint8(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_A_tile_int8_avxvnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_A_tile_int8_avx2(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_A_tile_int8"); + // assert A.elempack == 1 + // assert A.dims == 2 + + signed char* pp = AT; + + int ii = 0; 
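// Editorial note, not part of the original patch: the __AVX512VNNI__ / __AVXVNNI__
// packing paths below emit an extra "w_shift" row whenever max_kk >= 4. The VNNI
// dpbusd instruction multiplies unsigned bytes by signed bytes, so pack_B_tile_int8
// biases the B side by +127 while pack_A_tile_int8 records 127 * sum(A) per output
// row; the gemm kernel can later subtract that stored term to recover the true
// signed dot product. A minimal scalar sketch of the identity, using hypothetical
// names and assuming the quantized values stay within [-127, 127] so the biased
// operand fits an unsigned byte:
static inline int dpbusd_bias_sketch(const signed char* a, const signed char* b, int kk)
{
    int acc = 0;     // what dpbusd accumulates on the biased operands
    int w_shift = 0; // the correction stored alongside the packed A tile
    for (int k = 0; k < kk; k++)
    {
        acc += (int)a[k] * ((int)b[k] + 127); // unsigned (b + 127) times signed a
        w_shift += 127 * (int)a[k];
    }
    return acc - w_shift; // equals a[0]*b[0] + ... + a[kk-1]*b[kk-1]
}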
+#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* p0 = A.row(i + ii) + k; + + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A.w)); + + int kk = 0; +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_si512(); + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _p = _mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _p); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 64; + p0 += 4; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _p = _mm512_cvtepi32_epi16(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char))); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char))); + _mm_store_si128((__m128i*)pp, _p); + pp += 16; + p0++; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* p0 = A.row(i + ii) + k; + + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A.w)); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 4; + } +#else // __AVXVNNIINT8__ + __m256i _w_shift = _mm256_setzero_si256(); + __m256i _v127 = _mm256_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _p); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 4; + } + if (max_kk >= 4) + { + _mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p = _mm256_comp_cvtepi32_epi16(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _p); +#else + _mm_storeu_si128((__m128i*)pp, _p); +#endif + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm256_comp_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0++; + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(i + ii) + k; + +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A.w)); +#else + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; +#endif // __AVX2__ + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 16; + p0 += 4; + } +#else // __AVXVNNIINT8__ + __m128i _w_shift = _mm_setzero_si128(); + __m128i _v127 = _mm_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); 
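// Each 32-bit lane of _p now holds four consecutive k values of one A row; the
// dpbusd on the next line adds 127 * (sum of those four bytes) into the matching
// lane of _w_shift, the per-row correction term stored right after this loop
// (see the editorial note above).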
+ _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _p); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 16; + p0 += 4; + } + if (max_kk >= 4) + { + _mm_storeu_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi16(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += 2; +#else + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; +#endif // __AVX2__ + } + for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi8(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_store_ss((float*)pp, _mm_castsi128_ps(_p)); + pp += 4; + p0++; +#else + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; +#endif // __AVX2__ + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp += 8; + p0 += 4; + p1 += 4; + } +#else // __AVXVNNIINT8__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += 4; + p1 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(i + ii) + k; + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += 4; + } +#else // __AVXVNNIINT8__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_A_tile_int8_avx512vnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ 
&& !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + transpose_pack_A_tile_int8_avxvnniint8(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_A_tile_int8_avxvnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_A_tile_int8_avx2(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_A_tile_int8"); + // assert A.elempack == 1 + // assert A.dims == 2 + + const int A_hstep = A.w; + + signed char* pp = AT; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_si512(); + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep)); + __m128i _p2 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep * 2)); + __m128i _p3 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep * 3)); + transpose16x4_epi8(_p0, _p1, _p2, _p3); + __m512i _pp = combine4x4_epi32(_p0, _p1, _p2, _p3); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + _mm512_storeu_si512((__m512i*)pp, _pp); + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep)); + __m128i _t0 = _mm_unpacklo_epi8(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi8(_p0, _p1); + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + pp += 32; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)p0); + _mm_store_si128((__m128i*)pp, _p); + pp += 16; + p0 += A_hstep; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; + p0 += A_hstep * 4; + } +#else // __AVXVNNIINT8__ + __m256i _w_shift = _mm256_setzero_si256(); + __m256i _v127 = _mm256_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + 
_mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _pp = _mm_unpacklo_epi8(_p0, _p1); +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm_loadl_epi64((const __m128i*)p0); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += A_hstep; + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(k) + (i + ii); + +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; + p0 += A_hstep * 4; + } +#else // __AVXVNNIINT8__ + __m128i _w_shift = _mm_setzero_si128(); + __m128i _v127 = _mm_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + _mm_storeu_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[A_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[A_hstep + 3]; + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += A_hstep; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp += 8; + p0 += A_hstep * 4; + } +#else // __AVXVNNIINT8__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = 
p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp += 4; + p0 += A_hstep * 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp += 4; + p0 += A_hstep * 4; + } +#else // __AVXVNNIINT8__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } +} + +static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_B_tile_int8_avx512vnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + pack_B_tile_int8_avxvnniint8(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_B_tile_int8_avxvnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + signed char* pp = BT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* p0 = B.row(j + jj) + k; + + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B.w)); + + int kk = 0; +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _p = _mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)); + _p = _mm512_add_epi8(_p, _v127); + _mm512_storeu_si512((__m512i*)pp, _p); + pp += 64; + p0 += 4; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _p = _mm512_cvtepi32_epi16(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char))); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char))); + _mm_store_si128((__m128i*)pp, _p); + pp += 16; + p0++; + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(j + jj) + k; + +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B.w)); +#else + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + 
jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + const signed char* p4 = B.row(j + jj + 4) + k; + const signed char* p5 = B.row(j + jj + 5) + k; + const signed char* p6 = B.row(j + jj + 6) + k; + const signed char* p7 = B.row(j + jj + 7) + k; +#endif // __AVX2__ + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 4; + } +#else // __AVXVNNIINT8__ + __m256i _v127 = _mm256_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _p = _mm256_add_epi8(_p, _v127); + _mm256_storeu_si256((__m256i*)pp, _p); + pp += 32; + p0 += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _p = _mm256_comp_cvtepi32_epi16(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _p); +#else + _mm_storeu_si128((__m128i*)pp, _p); +#endif + pp += 16; + p0 += 2; +#else + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp += 16; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; +#endif // __AVX2__ + } + for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m128i _p = _mm256_comp_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0++; +#else + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; +#endif // __AVX2__ + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(j + jj) + k; + +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B.w)); +#else + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; +#endif // __AVX2__ + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 16; + p0 += 4; + } +#else // __AVXVNNIINT8__ + __m128i _v127 = _mm_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _p = _mm_add_epi8(_p, _v127); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 16; + p0 += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi16(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += 2; +#else + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; +#endif // __AVX2__ + } + 
for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi8(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_store_ss((float*)pp, _mm_castsi128_ps(_p)); + pp += 4; + p0++; +#else + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; +#endif // __AVX2__ + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp += 8; + p0 += 4; + p1 += 4; + } +#else // __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp[4] = p1[0] + 127; + pp[5] = p1[1] + 127; + pp[6] = p1[2] + 127; + pp[7] = p1[3] + 127; + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = B.row(j + jj) + k; + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += 4; + } +#else // __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp += 4; + p0 += 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_B_tile_int8_avx512vnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + transpose_pack_B_tile_int8_avxvnniint8(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_B_tile_int8_avxvnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_B_tile_int8_avx2(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + const int B_hstep = B.w; + + signed char* pp = BT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = 
_mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep)); + __m128i _p2 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep * 2)); + __m128i _p3 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep * 3)); + transpose16x4_epi8(_p0, _p1, _p2, _p3); + __m512i _pp = combine4x4_epi32(_p0, _p1, _p2, _p3); + _pp = _mm512_add_epi8(_pp, _v127); + _mm512_storeu_si512((__m512i*)pp, _pp); + pp += 64; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep)); + __m128i _t0 = _mm_unpacklo_epi8(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi8(_p0, _p1); + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + pp += 32; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm_loadu_si128((const __m128i*)p0); + _mm_store_si128((__m128i*)pp, _p); + pp += 16; + p0 += B_hstep; + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; + p0 += B_hstep * 4; + } +#else // __AVXVNNIINT8__ + __m256i _v127 = _mm256_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _pp = _mm256_add_epi8(_pp, _v127); + _mm256_storeu_si256((__m256i*)pp, _pp); + pp += 32; + p0 += B_hstep * 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _pp = _mm_unpacklo_epi8(_p0, _p1); +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128i _p = _mm_loadl_epi64((const __m128i*)p0); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += B_hstep; + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(k) + (j + jj); + +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; + p0 += B_hstep * 4; + } +#else // __AVXVNNIINT8__ + __m128i _v127 = 
_mm_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _pp = _mm_add_epi8(_pp, _v127); + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; + p0 += B_hstep * 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[B_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[B_hstep + 3]; + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += B_hstep; + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[B_hstep + 1]; + pp[6] = p0[B_hstep * 2 + 1]; + pp[7] = p0[B_hstep * 3 + 1]; + pp += 8; + p0 += B_hstep * 4; + } +#else // __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp[4] = p0[1] + 127; + pp[5] = p0[B_hstep + 1] + 127; + pp[6] = p0[B_hstep * 2 + 1] + 127; + pp[7] = p0[B_hstep * 3 + 1] + 127; + pp += 8; + p0 += B_hstep * 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp += 4; + p0 += B_hstep * 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ +#if __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp += 4; + p0 += B_hstep * 4; + } +#else // __AVXVNNIINT8__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp += 4; + p0 += B_hstep * 4; + } +#endif // __AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } +} + +static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; + + const int max_ii_packed = max_ii / elempack; + const int size = A.w * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _v127_avx512 = _mm512_set1_ps(127.f); + __m512 _v127_B_scale_avx512 = _mm512_set1_ps(v127_B_scale); +#endif // __AVX512F__ + __m256 _v127_avx = _mm256_set1_ps(127.f); + __m256 _v127_B_scale_avx = _mm256_set1_ps(v127_B_scale); +#endif // __AVX__ + __m128 _v127 = _mm_set1_ps(127.f); + __m128 _v127_B_scale = _mm_set1_ps(v127_B_scale); +#endif // __SSE2__ + + for (int ii = 0; ii < max_ii_packed; ii++) + { + const float* ptr = (const float*)A + (i + ii * elempack) * A_hstep; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _absmax_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ + __m256 _absmax_avx = _mm256_set1_ps(0.f); +#endif // __AVX__ + __m128 _absmax = _mm_set1_ps(0.f); +#endif // __SSE2__ + float absmax = 0.f; + + int kk = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; kk + 15 < size; kk += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; kk + 7 < size; kk += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; kk + 3 < size; kk += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + _absmax = _mm_max_ps(_absmax, abs_ps(_p)); + ptr += 4; + } +#endif // __SSE2__ + for (; kk < size; kk++) + { + absmax = std::max(absmax, (float)fabsf(*ptr)); + ptr++; + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_div_ps(_v127_avx512, _absmax_avx512); + __m512 _out_descale = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512); + _mm512_store_ps(ps, _scale); + _mm512_store_ps(pods, _out_descale); + ps += 16; + pods += 16; + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + { + __m256 _absmax0 = _mm512_castps512_ps256(_absmax_avx512); + __m256 _absmax1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_absmax_avx512), 1)); + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax0); + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax1); + } +#endif // __AVX512F__ + + __m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx); + __m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale); + _mm256_store_ps(pods, _out_descale); + ps += 8; + pods += 8; + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + { + __m256 _absmax0 = _mm512_castps512_ps256(_absmax_avx512); + __m256 _absmax1 = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(_absmax_avx512), 1)); + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax0); + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax1); + } +#endif // __AVX512F__ + { + __m128 _absmax0 = _mm256_castps256_ps128(_absmax_avx); + __m128 _absmax1 = _mm256_extractf128_ps(_absmax_avx, 1); + _absmax = _mm_max_ps(_absmax, _absmax0); + _absmax = _mm_max_ps(_absmax, _absmax1); + } +#endif // __AVX__ + + __m128 _scale = _mm_div_ps(_v127, _absmax); + __m128 _out_descale = _mm_div_ps(_absmax, _v127_B_scale); + _mm_store_ps(ps, _scale); + _mm_store_ps(pods, _out_descale); + ps += 4; + pods += 4; + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + absmax = 
std::max(absmax, _mm512_comp_reduce_max_ps(_absmax_avx512)); +#endif // __AVX512F__ + absmax = std::max(absmax, _mm256_reduce_max_ps(_absmax_avx)); +#endif // __AVX__ + absmax = std::max(absmax, _mm_reduce_max_ps(_absmax)); +#endif // __SSE2__ + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_A_tile_fp32_to_int8_avx512vnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + pack_A_tile_fp32_to_int8_avxvnniint8(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_A_tile_fp32_to_int8_avxvnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_A_tile_fp32_to_int8_avx2(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_fp32_to_int8 %d %d %d", max_ii, max_kk, elempack); + + signed char* pp = (signed char*)AT; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + __m512 _scales = _mm512_load_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_si512(); + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + + if (elempack == 16) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = _mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 64; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += 32; + } + for (; kk < max_kk; kk++) + { + __m512 _p = _mm512_load_ps(p0); + + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp 
= float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 16; + } + } + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 8 + 16); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _t2 = _mm512_shuffle_f32x4(_p1, _p3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t3 = _mm512_shuffle_f32x4(_p1, _p3, _MM_SHUFFLE(3, 2, 3, 2)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + _t2 = _mm512_mul_ps(_t2, _scales); + _t3 = _mm512_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 32; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep * 8); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(3, 2, 3, 2)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _tt0); + _mm_store_si128((__m128i*)(pp + 16), _tt1); + + pp += 32; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + A_hstep * 8); + + __m512 _p = combine8x2_ps(_p0, _p1); + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 8; + } + } + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep * 4); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 12); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _t2 = _mm512_shuffle_f32x4(_p2, _p3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t3 = _mm512_shuffle_f32x4(_p2, _p3, _MM_SHUFFLE(3, 2, 3, 2)); + + _p0 = _mm512_shuffle_f32x4(_t0, _t2, _MM_SHUFFLE(2, 0, 2, 0)); + _p1 = _mm512_shuffle_f32x4(_t0, _t2, _MM_SHUFFLE(3, 1, 3, 1)); + _p2 = _mm512_shuffle_f32x4(_t1, _t3, _MM_SHUFFLE(2, 0, 2, 0)); + _p3 = _mm512_shuffle_f32x4(_t1, _t3, _MM_SHUFFLE(3, 1, 3, 1)); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = _mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, 
_pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 16; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep * 4); + __m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 8); + __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 12); + + __m512 _p01 = combine8x2_ps(_p0, _p1); + __m512 _p23 = combine8x2_ps(_p2, _p3); + + __m512 _t0 = _mm512_shuffle_f32x4(_p01, _p23, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p01, _p23, _MM_SHUFFLE(3, 1, 3, 1)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _tt0); + _mm_store_si128((__m128i*)(pp + 16), _tt1); + + pp += 32; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + A_hstep * 4); + __m128 _p2 = _mm_load_ps(p0 + A_hstep * 8); + __m128 _p3 = _mm_load_ps(p0 + A_hstep * 12); + + __m512 _p = combine4x4_ps(_p0, _p1, _p2, _p3); + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + A_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + A_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + A_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + A_hstep * 7); + __m128 _p8 = _mm_loadu_ps(p0 + A_hstep * 8); + __m128 _p9 = _mm_loadu_ps(p0 + A_hstep * 9); + __m128 _pa = _mm_loadu_ps(p0 + A_hstep * 10); + __m128 _pb = _mm_loadu_ps(p0 + A_hstep * 11); + __m128 _pc = _mm_loadu_ps(p0 + A_hstep * 12); + __m128 _pd = _mm_loadu_ps(p0 + A_hstep * 13); + __m128 _pe = _mm_loadu_ps(p0 + A_hstep * 14); + __m128 _pf = _mm_loadu_ps(p0 + A_hstep * 15); + + __m512 _t0 = combine4x4_ps(_p0, _p4, _p8, _pc); + __m512 _t1 = combine4x4_ps(_p1, _p5, _p9, _pd); + __m512 _t2 = combine4x4_ps(_p2, _p6, _pa, _pe); + __m512 _t3 = combine4x4_ps(_p3, _p7, _pb, _pf); + + __m512 _t4 = _mm512_unpacklo_ps(_t0, _t1); + __m512 _t5 = _mm512_unpackhi_ps(_t0, _t1); + __m512 _t6 = _mm512_unpacklo_ps(_t2, _t3); + __m512 _t7 = _mm512_unpackhi_ps(_t2, _t3); + + _t0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_t4), _mm512_castps_pd(_t6))); + _t1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_t4), _mm512_castps_pd(_t6))); + _t2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_t5), _mm512_castps_pd(_t7))); + _t3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_t5), _mm512_castps_pd(_t7))); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + _t2 = _mm512_mul_ps(_t2, _scales); + _t3 = _mm512_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = 
_mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 4; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _p0 = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, p0 + 1, sizeof(float)); + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _p = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0++; + } + } + } +#endif // __AVX512F__ +#if !__AVX2__ + signed char* pp1 = pp + max_kk * 4; +#endif + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + __m256 _scales = _mm256_load_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m256i _w_shift = _mm256_setzero_si256(); + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + _p2 = _mm256_mul_ps(_p2, _scales); + _p3 = _mm256_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 32; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX2__ +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; +#else + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); + pp += 8; + pp1 += 8; +#endif + p0 += 16; + } 
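+ // leftover k: quantize one remaining column of 8 floats at a time.
+ // note: on non-AVX2 targets this 8-row block is emitted as two 4-row halves, which is why
+ // the low 32 bits of the packed result go to pp and the high 32 bits to pp1 (pp1 = pp + max_kk * 4 above).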
+ for (; kk < max_kk; kk++) + { + __m256 _p = _mm256_load_ps(p0); + + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + +#if __AVX2__ + *(int64_t*)pp = v; + pp += 8; +#else + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); + pp += 4; + pp1 += 4; +#endif + p0 += 8; + } + } + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 4); + __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 4 + 8); + + __m256 _t0 = _mm256_permute2f128_ps(_p0, _p2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t1 = _mm256_permute2f128_ps(_p0, _p2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _t2 = _mm256_permute2f128_ps(_p1, _p3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t3 = _mm256_permute2f128_ps(_p1, _p3, _MM_SHUFFLE(0, 3, 0, 1)); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + _t2 = _mm256_mul_ps(_t2, _scales); + _t3 = _mm256_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx(_t0, _t2); + __m128i _pp1 = float2int8_avx(_t1, _t3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 16; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep * 4); + + __m256 _t0 = _mm256_permute2f128_ps(_p0, _p1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _t1 = _mm256_permute2f128_ps(_p0, _p1, _MM_SHUFFLE(0, 3, 0, 1)); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + + __m128i _pp = float2int8_avx(_t0, _t1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX2__ +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; +#else + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); + pp += 8; + pp1 += 8; +#endif + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + A_hstep * 4); + + __m256 _p = combine4x2_ps(_p0, _p1); + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + +#if __AVX2__ + *(int64_t*)pp = v; + pp += 8; +#else + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); + pp += 4; + pp1 += 4; +#endif + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + A_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + A_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + A_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + A_hstep * 7); + + __m256 _t0 = combine4x2_ps(_p0, _p4); + __m256 _t1 = combine4x2_ps(_p1, _p5); + __m256 _t2 = combine4x2_ps(_p2, _p6); + __m256 _t3 = 
combine4x2_ps(_p3, _p7); + + __m256 _t4 = _mm256_unpacklo_ps(_t0, _t1); + __m256 _t5 = _mm256_unpackhi_ps(_t0, _t1); + __m256 _t6 = _mm256_unpacklo_ps(_t2, _t3); + __m256 _t7 = _mm256_unpackhi_ps(_t2, _t3); + + _t0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_t4), _mm256_castps_pd(_t6))); + _t1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_t4), _mm256_castps_pd(_t6))); + _t2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_t5), _mm256_castps_pd(_t7))); + _t3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_t5), _mm256_castps_pd(_t7))); + + _t0 = _mm256_mul_ps(_t0, _scales); + _t1 = _mm256_mul_ps(_t1, _scales); + _t2 = _mm256_mul_ps(_t2, _scales); + _t3 = _mm256_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx(_t0, _t2); + __m128i _pp1 = float2int8_avx(_t1, _t3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A_hstep)); + + __m256 _p0 = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); + __m256 _p1 = _mm256_i32gather_ps(p0 + 1, _vindex, sizeof(float)); +#else + __m256 _p0 = _mm256_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3], p0[A_hstep * 4], p0[A_hstep * 5], p0[A_hstep * 6], p0[A_hstep * 7]); + __m256 _p1 = _mm256_setr_ps(p0[1], p0[A_hstep + 1], p0[A_hstep * 2 + 1], p0[A_hstep * 3 + 1], p0[A_hstep * 4 + 1], p0[A_hstep * 5 + 1], p0[A_hstep * 6 + 1], p0[A_hstep * 7 + 1]); +#endif + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX2__ +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; +#else + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); + pp += 8; + pp1 += 8; +#endif + p0 += 2; + } + for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A_hstep)); + + __m256 _p = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m256 _p = _mm256_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3], p0[A_hstep * 4], p0[A_hstep * 5], p0[A_hstep * 6], p0[A_hstep * 7]); +#endif + + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + +#if __AVX2__ + *(int64_t*)pp = v; + pp += 8; +#else + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); + pp += 4; + pp1 += 4; +#endif + p0++; + } + } + +#if !__AVX2__ + pp = pp1; + pp1 = pp + max_kk * 4; +#endif + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + __m128 _scales = _mm_load_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ || 
(__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _w_shift = _mm_setzero_si128(); + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + _p2 = _mm_mul_ps(_p2, _scales); + _p3 = _mm_mul_ps(_p3, _scales); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 16; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm_storeu_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p = _mm_load_ps(p0); + _p = _mm_mul_ps(_p, _scales); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + + _MM_TRANSPOSE4_PS(_p0, _p1, _p2, _p3); + + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + _p2 = _mm_mul_ps(_p2, _scales); + _p3 = _mm_mul_ps(_p3, _scales); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm_storeu_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep)); + + __m128 _p0 = _mm_i32gather_ps(p0, _vindex, sizeof(float)); + __m128 _p1 = _mm_i32gather_ps(p0 + 1, _vindex, sizeof(float)); +#else + __m128 _p0 = _mm_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3]); + __m128 _p1 = _mm_setr_ps(p0[1], p0[A_hstep + 1], p0[A_hstep * 2 + 1], p0[A_hstep * 3 + 1]); +#endif + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep)); + + __m128 _p = _mm_i32gather_ps(p0, _vindex, 
sizeof(float)); +#else + __m128 _p = _mm_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3]); +#endif + _p = _mm_mul_ps(_p, _scales); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0++; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; +#if __SSE2__ + __m128 _scales0 = _mm_set1_ps(scale0); + __m128 _scales1 = _mm_set1_ps(scale1); + __m128 _scales0011 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_scales0), _mm_castps_pd(_scales1))); +#endif // __SSE2__ + + // if (elempack == 1) + { + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + _p0 = _mm_mul_ps(_p0, _scales0); + _p1 = _mm_mul_ps(_p1, _scales1); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep))); + __m128 _p = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + _p = _mm_mul_ps(_p, _scales0011); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale = scales[i + ii]; +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); +#endif // __SSE2__ + + // if (elempack == 1) + { + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p = _mm_loadu_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 4; + p0 += 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = 
float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep * elempack : A.w * elempack; + const int K = A.dims == 3 ? A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = (float*)scales + i; + float* pods = (float*)out_descales + i; + + const int max_ii_unpacked = max_ii * elempack; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _v127_avx512 = _mm512_set1_ps(127.f); + __m512 _v127_B_scale_avx512 = _mm512_set1_ps(v127_B_scale); +#endif // __AVX512F__ + __m256 _v127_avx = _mm256_set1_ps(127.f); + __m256 _v127_B_scale_avx = _mm256_set1_ps(v127_B_scale); +#endif // __AVX__ + __m128 _v127 = _mm_set1_ps(127.f); + __m128 _v127_B_scale = _mm_set1_ps(v127_B_scale); +#endif // __SSE2__ + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 63 < max_ii_unpacked; ii += 64) + { + const float* ptr = (const float*)A + i * elempack + ii; + + __m512 _absmax0_avx512 = _mm512_setzero_ps(); + __m512 _absmax1_avx512 = _mm512_setzero_ps(); + __m512 _absmax2_avx512 = _mm512_setzero_ps(); + __m512 _absmax3_avx512 = _mm512_setzero_ps(); + + for (int kk = 0; kk < K; kk++) + { + __m512 _p0 = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr + 16); + __m512 _p2 = _mm512_loadu_ps(ptr + 32); + __m512 _p3 = _mm512_loadu_ps(ptr + 48); + _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); + _absmax2_avx512 = _mm512_max_ps(_absmax2_avx512, abs512_ps(_p2)); + _absmax3_avx512 = _mm512_max_ps(_absmax3_avx512, abs512_ps(_p3)); + ptr += A_hstep; + } + + if (elempack == 16) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp2 = _mm512_unpacklo_ps(_absmax1_avx512, _absmax3_avx512); + __m512 _tmp3 = _mm512_unpackhi_ps(_absmax1_avx512, _absmax3_avx512); + _absmax0_avx512 = _mm512_max_ps(_tmp0, _tmp1); + _absmax1_avx512 = _mm512_max_ps(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax1_avx512); + _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512); + __m512 _absmax_avx512 = _mm512_max_ps(_tmp0, _tmp1); + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_absmax_avx512, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_absmax_avx512, 1); + __m256 _absmax_avx = _mm256_max_ps(_absmax0_avx, _absmax1_avx); + __m128 _absmax0 = _mm256_extractf128_ps(_absmax_avx, 0); + __m128 _absmax1 = _mm256_extractf128_ps(_absmax_avx, 1); + __m128 _absmax = _mm_max_ps(_absmax0, _absmax1); + __m128 _scale0 = _mm_div_ps(_v127, _absmax); + __m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale); + _mm_store_ps(ps, _scale0); + _mm_store_ps(pods, _out_descale0); + ps += 4; + pods += 4; + } + if (elempack == 8) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp2 = _mm512_unpacklo_ps(_absmax1_avx512, _absmax3_avx512); + __m512 _tmp3 = _mm512_unpackhi_ps(_absmax1_avx512, _absmax3_avx512); + _absmax0_avx512 = _mm512_max_ps(_tmp0, _tmp1); + _absmax1_avx512 = _mm512_max_ps(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax1_avx512); + _tmp1 = 
_mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512); + __m512 _absmax_avx512 = _mm512_max_ps(_tmp0, _tmp1); + _absmax_avx512 = _mm512_permutexvar_ps(_mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15), _absmax_avx512); + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_absmax_avx512, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_absmax_avx512, 1); + __m256 _absmax_avx = _mm256_max_ps(_absmax0_avx, _absmax1_avx); + __m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx); + __m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale); + _mm256_store_ps(pods, _out_descale); + ps += 8; + pods += 8; + } + if (elempack == 4) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax2_avx512); + __m512 _tmp2 = _mm512_unpacklo_ps(_absmax1_avx512, _absmax3_avx512); + __m512 _tmp3 = _mm512_unpackhi_ps(_absmax1_avx512, _absmax3_avx512); + _absmax0_avx512 = _mm512_max_ps(_tmp0, _tmp1); + _absmax1_avx512 = _mm512_max_ps(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax1_avx512); + _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512); + __m512 _absmax_avx512 = _mm512_max_ps(_tmp0, _tmp1); + _absmax_avx512 = _mm512_permutexvar_ps(_mm512_setr_epi32(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15), _absmax_avx512); + __m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax_avx512); + __m512 _out_descale0 = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512); + _mm512_store_ps(ps, _scale0); + _mm512_store_ps(pods, _out_descale0); + ps += 16; + pods += 16; + } + if (elempack == 1) + { + __m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax0_avx512); + __m512 _scale1 = _mm512_div_ps(_v127_avx512, _absmax1_avx512); + __m512 _scale2 = _mm512_div_ps(_v127_avx512, _absmax2_avx512); + __m512 _scale3 = _mm512_div_ps(_v127_avx512, _absmax3_avx512); + __m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _v127_B_scale_avx512); + __m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _v127_B_scale_avx512); + __m512 _out_descale2 = _mm512_div_ps(_absmax2_avx512, _v127_B_scale_avx512); + __m512 _out_descale3 = _mm512_div_ps(_absmax3_avx512, _v127_B_scale_avx512); + _mm512_store_ps(ps, _scale0); + _mm512_store_ps(ps + 16, _scale1); + _mm512_store_ps(ps + 32, _scale2); + _mm512_store_ps(ps + 48, _scale3); + _mm512_store_ps(pods, _out_descale0); + _mm512_store_ps(pods + 16, _out_descale1); + _mm512_store_ps(pods + 32, _out_descale2); + _mm512_store_ps(pods + 48, _out_descale3); + ps += 64; + pods += 64; + } + } +#endif // __AVX512F__ + for (; ii + 31 < max_ii_unpacked; ii += 32) + { + const float* ptr = (const float*)A + i * elempack + ii; + +#if __AVX512F__ + __m512 _absmax0_avx512 = _mm512_setzero_ps(); + __m512 _absmax1_avx512 = _mm512_setzero_ps(); + __m512 _absmax2_avx512 = _mm512_setzero_ps(); + __m512 _absmax3_avx512 = _mm512_setzero_ps(); +#else + __m256 _absmax0_avx = _mm256_setzero_ps(); + __m256 _absmax1_avx = _mm256_setzero_ps(); + __m256 _absmax2_avx = _mm256_setzero_ps(); + __m256 _absmax3_avx = _mm256_setzero_ps(); +#endif + + int kk = 0; +#if __AVX512F__ + for (; kk + 1 < K; kk += 2) + { + __m512 _p0 = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr + 16); + __m512 _p2 = _mm512_loadu_ps(ptr + A_hstep); + __m512 _p3 = _mm512_loadu_ps(ptr + A_hstep + 16); + _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); + _absmax2_avx512 = _mm512_max_ps(_absmax2_avx512, 
abs512_ps(_p2)); + _absmax3_avx512 = _mm512_max_ps(_absmax3_avx512, abs512_ps(_p3)); + ptr += A_hstep * 2; + } + _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, _absmax2_avx512); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, _absmax3_avx512); +#endif // __AVX512F__ + for (; kk < K; kk++) + { +#if __AVX512F__ + __m512 _p0 = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr + 16); + _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); +#else + __m256 _p0 = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr + 8); + __m256 _p2 = _mm256_loadu_ps(ptr + 16); + __m256 _p3 = _mm256_loadu_ps(ptr + 24); + _absmax0_avx = _mm256_max_ps(_absmax0_avx, abs256_ps(_p0)); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, abs256_ps(_p1)); + _absmax2_avx = _mm256_max_ps(_absmax2_avx, abs256_ps(_p2)); + _absmax3_avx = _mm256_max_ps(_absmax3_avx, abs256_ps(_p3)); +#endif + ptr += A_hstep; + } + +#if __AVX512F__ + if (elempack == 16) + { + float absmax0 = _mm512_reduce_max_ps(_absmax0_avx512); + float absmax1 = _mm512_reduce_max_ps(_absmax1_avx512); + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + __m512 _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax1_avx512); + __m512 _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512); + _tmp0 = _mm512_max_ps(_tmp0, _tmp1); + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_tmp0, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_tmp0, 1); +#else + __m256 _tmp0 = _mm256_unpacklo_ps(_absmax0_avx, _absmax2_avx); + __m256 _tmp1 = _mm256_unpackhi_ps(_absmax0_avx, _absmax2_avx); + __m256 _tmp2 = _mm256_unpacklo_ps(_absmax1_avx, _absmax3_avx); + __m256 _tmp3 = _mm256_unpackhi_ps(_absmax1_avx, _absmax3_avx); + _absmax0_avx = _mm256_max_ps(_tmp0, _tmp1); + _absmax1_avx = _mm256_max_ps(_tmp2, _tmp3); +#endif + __m256 _t0 = _mm256_unpacklo_ps(_absmax0_avx, _absmax1_avx); + __m256 _t1 = _mm256_unpackhi_ps(_absmax0_avx, _absmax1_avx); + _t0 = _mm256_max_ps(_t0, _t1); + __m128 _absmax0 = _mm256_extractf128_ps(_t0, 0); + __m128 _absmax1 = _mm256_extractf128_ps(_t0, 1); + __m128 _absmax = _mm_max_ps(_absmax0, _absmax1); + __m128 _scale0 = _mm_div_ps(_v127, _absmax); + __m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale); + _mm_store_ps(ps, _scale0); + _mm_store_ps(pods, _out_descale0); + ps += 4; + pods += 4; + } + if (elempack == 4) + { +#if __AVX512F__ + __m512 _tmp0 = _mm512_unpacklo_ps(_absmax0_avx512, _absmax1_avx512); + __m512 _tmp1 = _mm512_unpackhi_ps(_absmax0_avx512, _absmax1_avx512); + _tmp0 = _mm512_max_ps(_tmp0, _tmp1); + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_tmp0, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_tmp0, 1); +#else + __m256 _tmp0 = _mm256_unpacklo_ps(_absmax0_avx, _absmax2_avx); + __m256 _tmp1 = _mm256_unpackhi_ps(_absmax0_avx, _absmax2_avx); + __m256 _tmp2 = _mm256_unpacklo_ps(_absmax1_avx, _absmax3_avx); + __m256 _tmp3 = _mm256_unpackhi_ps(_absmax1_avx, _absmax3_avx); + _absmax0_avx = _mm256_max_ps(_tmp0, _tmp1); + _absmax1_avx = _mm256_max_ps(_tmp2, _tmp3); +#endif + __m256 _t0 = _mm256_unpacklo_ps(_absmax0_avx, _absmax1_avx); + __m256 _t1 = _mm256_unpackhi_ps(_absmax0_avx, _absmax1_avx); + __m256 _absmax_avx = _mm256_max_ps(_t0, _t1); + __m128 _tt0 = _mm256_extractf128_ps(_absmax_avx, 0); + __m128 _tt1 = _mm256_extractf128_ps(_absmax_avx, 1); + __m128 _absmax0 = 
_mm_unpacklo_ps(_tt0, _tt1); + __m128 _absmax1 = _mm_unpackhi_ps(_tt0, _tt1); + _absmax_avx = combine4x2_ps(_absmax0, _absmax1); + __m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx); + __m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale); + _mm256_store_ps(pods, _out_descale); + ps += 8; + pods += 8; + } + if (elempack == 1) + { +#if __AVX512F__ + __m512 _scale0 = _mm512_div_ps(_v127_avx512, _absmax0_avx512); + __m512 _scale1 = _mm512_div_ps(_v127_avx512, _absmax1_avx512); + __m512 _out_descale0 = _mm512_div_ps(_absmax0_avx512, _v127_B_scale_avx512); + __m512 _out_descale1 = _mm512_div_ps(_absmax1_avx512, _v127_B_scale_avx512); + _mm512_store_ps(ps, _scale0); + _mm512_store_ps(ps + 16, _scale1); + _mm512_store_ps(pods, _out_descale0); + _mm512_store_ps(pods + 16, _out_descale1); +#else + __m256 _scale0 = _mm256_div_ps(_v127_avx, _absmax0_avx); + __m256 _scale1 = _mm256_div_ps(_v127_avx, _absmax1_avx); + __m256 _scale2 = _mm256_div_ps(_v127_avx, _absmax2_avx); + __m256 _scale3 = _mm256_div_ps(_v127_avx, _absmax3_avx); + __m256 _out_descale0 = _mm256_div_ps(_absmax0_avx, _v127_B_scale_avx); + __m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _v127_B_scale_avx); + __m256 _out_descale2 = _mm256_div_ps(_absmax2_avx, _v127_B_scale_avx); + __m256 _out_descale3 = _mm256_div_ps(_absmax3_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale0); + _mm256_store_ps(ps + 8, _scale1); + _mm256_store_ps(ps + 16, _scale2); + _mm256_store_ps(ps + 24, _scale3); + _mm256_store_ps(pods, _out_descale0); + _mm256_store_ps(pods + 8, _out_descale1); + _mm256_store_ps(pods + 16, _out_descale2); + _mm256_store_ps(pods + 24, _out_descale3); +#endif + ps += 32; + pods += 32; + } + } +#endif // __AVX__ + for (; ii + 15 < max_ii_unpacked; ii += 16) + { + const float* ptr = (const float*)A + i * elempack + ii; + +#if __AVX512F__ + __m512 _absmax_avx512 = _mm512_setzero_ps(); + __m512 _absmax1_avx512 = _mm512_setzero_ps(); + __m512 _absmax2_avx512 = _mm512_setzero_ps(); + __m512 _absmax3_avx512 = _mm512_setzero_ps(); +#elif __AVX__ + __m256 _absmax0_avx = _mm256_setzero_ps(); + __m256 _absmax1_avx = _mm256_setzero_ps(); + __m256 _absmax2_avx = _mm256_setzero_ps(); + __m256 _absmax3_avx = _mm256_setzero_ps(); +#else + __m128 _absmax0 = _mm_setzero_ps(); + __m128 _absmax1 = _mm_setzero_ps(); + __m128 _absmax2 = _mm_setzero_ps(); + __m128 _absmax3 = _mm_setzero_ps(); +#endif + + int kk = 0; +#if __AVX__ +#if __AVX512F__ + for (; kk + 3 < K; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr + A_hstep); + __m512 _p2 = _mm512_loadu_ps(ptr + A_hstep * 2); + __m512 _p3 = _mm512_loadu_ps(ptr + A_hstep * 3); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); + _absmax2_avx512 = _mm512_max_ps(_absmax2_avx512, abs512_ps(_p2)); + _absmax3_avx512 = _mm512_max_ps(_absmax3_avx512, abs512_ps(_p3)); + ptr += A_hstep * 4; + } + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, _absmax2_avx512); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, _absmax3_avx512); +#endif // __AVX512F__ + for (; kk + 1 < K; kk += 2) + { +#if __AVX512F__ + __m512 _p0 = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr + A_hstep); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); +#else + __m256 _p0 = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr + 8); + __m256 _p2 = _mm256_loadu_ps(ptr + A_hstep); + 
__m256 _p3 = _mm256_loadu_ps(ptr + A_hstep + 8); + _absmax0_avx = _mm256_max_ps(_absmax0_avx, abs256_ps(_p0)); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, abs256_ps(_p1)); + _absmax2_avx = _mm256_max_ps(_absmax2_avx, abs256_ps(_p2)); + _absmax3_avx = _mm256_max_ps(_absmax3_avx, abs256_ps(_p3)); +#endif + ptr += A_hstep * 2; + } +#if __AVX512F__ + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, _absmax1_avx512); +#else + _absmax0_avx = _mm256_max_ps(_absmax0_avx, _absmax2_avx); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, _absmax3_avx); +#endif +#endif // __AVX__ + for (; kk < K; kk++) + { +#if __AVX512F__ + __m512 _p = _mm512_loadu_ps(ptr); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p)); +#elif __AVX__ + __m256 _p0 = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr + 8); + _absmax0_avx = _mm256_max_ps(_absmax0_avx, abs256_ps(_p0)); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, abs256_ps(_p1)); +#else + __m128 _p0 = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr + 4); + __m128 _p2 = _mm_loadu_ps(ptr + 8); + __m128 _p3 = _mm_loadu_ps(ptr + 12); + _absmax0 = _mm_max_ps(_absmax0, abs_ps(_p0)); + _absmax1 = _mm_max_ps(_absmax1, abs_ps(_p1)); + _absmax2 = _mm_max_ps(_absmax2, abs_ps(_p2)); + _absmax3 = _mm_max_ps(_absmax3, abs_ps(_p3)); +#endif + ptr += A_hstep; + } + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + float absmax = _mm512_reduce_max_ps(_absmax_avx512); + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } +#endif // __AVX512F__ + if (elempack == 8) + { +#if __AVX512F__ + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_absmax_avx512, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_absmax_avx512, 1); +#endif + float absmax0 = _mm256_reduce_max_ps(_absmax0_avx); + float absmax1 = _mm256_reduce_max_ps(_absmax1_avx); + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ +#if __AVX512F__ + __m256 _absmax0_avx = _mm512_extractf32x8_ps(_absmax_avx512, 0); + __m256 _absmax1_avx = _mm512_extractf32x8_ps(_absmax_avx512, 1); +#endif + __m256 _tmp0 = _mm256_unpacklo_ps(_absmax0_avx, _absmax1_avx); + __m256 _tmp1 = _mm256_unpackhi_ps(_absmax0_avx, _absmax1_avx); + __m256 _absmax01_avx = _mm256_max_ps(_tmp0, _tmp1); + __m128 _absmax0 = _mm256_extractf128_ps(_absmax01_avx, 0); + __m128 _absmax1 = _mm256_extractf128_ps(_absmax01_avx, 1); +#else + __m128 _tmp0 = _mm_unpacklo_ps(_absmax0, _absmax2); + __m128 _tmp1 = _mm_unpackhi_ps(_absmax0, _absmax2); + __m128 _tmp2 = _mm_unpacklo_ps(_absmax1, _absmax3); + __m128 _tmp3 = _mm_unpackhi_ps(_absmax1, _absmax3); + _absmax0 = _mm_max_ps(_tmp0, _tmp1); + _absmax1 = _mm_max_ps(_tmp2, _tmp3); +#endif + __m128 _t0 = _mm_unpacklo_ps(_absmax0, _absmax1); + __m128 _t1 = _mm_unpackhi_ps(_absmax0, _absmax1); + __m128 _absmax = _mm_max_ps(_t0, _t1); + __m128 _scale0 = _mm_div_ps(_v127, _absmax); + __m128 _out_descale0 = _mm_div_ps(_absmax, _v127_B_scale); + _mm_store_ps(ps, _scale0); + _mm_store_ps(pods, _out_descale0); + ps += 4; + pods += 4; + } + if (elempack == 1) + { +#if __AVX512F__ + __m512 _scale = _mm512_div_ps(_v127_avx512, _absmax_avx512); + __m512 _out_descale = _mm512_div_ps(_absmax_avx512, _v127_B_scale_avx512); + _mm512_store_ps(ps, _scale); + _mm512_store_ps(pods, _out_descale); +#elif __AVX__ + __m256 _scale0 = _mm256_div_ps(_v127_avx, _absmax0_avx); + __m256 _scale1 = _mm256_div_ps(_v127_avx, _absmax1_avx); + __m256 
_out_descale0 = _mm256_div_ps(_absmax0_avx, _v127_B_scale_avx); + __m256 _out_descale1 = _mm256_div_ps(_absmax1_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale0); + _mm256_store_ps(ps + 8, _scale1); + _mm256_store_ps(pods, _out_descale0); + _mm256_store_ps(pods + 8, _out_descale1); +#else + __m128 _scale0 = _mm_div_ps(_v127, _absmax0); + __m128 _scale1 = _mm_div_ps(_v127, _absmax1); + __m128 _scale2 = _mm_div_ps(_v127, _absmax2); + __m128 _scale3 = _mm_div_ps(_v127, _absmax3); + __m128 _out_descale0 = _mm_div_ps(_absmax0, _v127_B_scale); + __m128 _out_descale1 = _mm_div_ps(_absmax1, _v127_B_scale); + __m128 _out_descale2 = _mm_div_ps(_absmax2, _v127_B_scale); + __m128 _out_descale3 = _mm_div_ps(_absmax3, _v127_B_scale); + _mm_store_ps(ps, _scale0); + _mm_store_ps(ps + 4, _scale1); + _mm_store_ps(ps + 8, _scale2); + _mm_store_ps(ps + 12, _scale3); + _mm_store_ps(pods, _out_descale0); + _mm_store_ps(pods + 4, _out_descale1); + _mm_store_ps(pods + 8, _out_descale2); + _mm_store_ps(pods + 12, _out_descale3); +#endif + ps += 16; + pods += 16; + } + } + for (; ii + 7 < max_ii_unpacked; ii += 8) + { + const float* ptr = (const float*)A + i * elempack + ii; + +#if __AVX__ + __m256 _absmax_avx = _mm256_setzero_ps(); + __m256 _absmax1_avx = _mm256_setzero_ps(); + __m256 _absmax2_avx = _mm256_setzero_ps(); + __m256 _absmax3_avx = _mm256_setzero_ps(); +#else + __m128 _absmax0 = _mm_setzero_ps(); + __m128 _absmax1 = _mm_setzero_ps(); + __m128 _absmax2 = _mm_setzero_ps(); + __m128 _absmax3 = _mm_setzero_ps(); +#endif + + int kk = 0; +#if __AVX__ + for (; kk + 3 < K; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr + A_hstep); + __m256 _p2 = _mm256_loadu_ps(ptr + A_hstep * 2); + __m256 _p3 = _mm256_loadu_ps(ptr + A_hstep * 3); + _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p0)); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, abs256_ps(_p1)); + _absmax2_avx = _mm256_max_ps(_absmax2_avx, abs256_ps(_p2)); + _absmax3_avx = _mm256_max_ps(_absmax3_avx, abs256_ps(_p3)); + ptr += A_hstep * 4; + } + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax2_avx); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, _absmax3_avx); +#endif // __AVX__ + for (; kk + 1 < K; kk += 2) + { +#if __AVX__ + __m256 _p0 = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr + A_hstep); + _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p0)); + _absmax1_avx = _mm256_max_ps(_absmax1_avx, abs256_ps(_p1)); +#else + __m128 _p0 = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr + 4); + __m128 _p2 = _mm_loadu_ps(ptr + A_hstep); + __m128 _p3 = _mm_loadu_ps(ptr + A_hstep + 4); + _absmax0 = _mm_max_ps(_absmax0, abs_ps(_p0)); + _absmax1 = _mm_max_ps(_absmax1, abs_ps(_p1)); + _absmax2 = _mm_max_ps(_absmax2, abs_ps(_p2)); + _absmax3 = _mm_max_ps(_absmax3, abs_ps(_p3)); +#endif + ptr += A_hstep * 2; + } +#if __AVX__ + _absmax_avx = _mm256_max_ps(_absmax_avx, _absmax1_avx); +#else + _absmax0 = _mm_max_ps(_absmax0, _absmax2); + _absmax1 = _mm_max_ps(_absmax1, _absmax3); +#endif + for (; kk < K; kk++) + { +#if __AVX__ + __m256 _p = _mm256_loadu_ps(ptr); + _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p)); +#else + __m128 _p0 = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr + 4); + _absmax0 = _mm_max_ps(_absmax0, abs_ps(_p0)); + _absmax1 = _mm_max_ps(_absmax1, abs_ps(_p1)); +#endif + ptr += A_hstep; + } + +#if __AVX__ + if (elempack == 8) + { + float absmax = _mm256_reduce_max_ps(_absmax_avx); + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } +#endif // __AVX__ + 
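+ // pack4 case: the 8 running absmax lanes above span two pack4 rows (4 channels each);
+ // reduce each 4-wide half to a scalar so each row gets its own quantization parameters:
+ //   scale[r]       = 127.f / absmax[r]
+ //   out_descale[r] = absmax[r] / (127.f * B_scale)   // v127_B_scale precomputed above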
if (elempack == 4) + { +#if __AVX__ + __m128 _absmax0 = _mm256_extractf128_ps(_absmax_avx, 0); + __m128 _absmax1 = _mm256_extractf128_ps(_absmax_avx, 1); +#endif + float absmax0 = _mm_reduce_max_ps(_absmax0); + float absmax1 = _mm_reduce_max_ps(_absmax1); + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } + if (elempack == 1) + { +#if __AVX__ + __m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx); + __m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx); + _mm256_store_ps(ps, _scale); + _mm256_store_ps(pods, _out_descale); +#else + __m128 _scale0 = _mm_div_ps(_v127, _absmax0); + __m128 _scale1 = _mm_div_ps(_v127, _absmax1); + __m128 _out_descale0 = _mm_div_ps(_absmax0, _v127_B_scale); + __m128 _out_descale1 = _mm_div_ps(_absmax1, _v127_B_scale); + _mm_store_ps(ps, _scale0); + _mm_store_ps(ps + 4, _scale1); + _mm_store_ps(pods, _out_descale0); + _mm_store_ps(pods + 4, _out_descale1); +#endif + ps += 8; + pods += 8; + } + } +#endif // __SSE2__ + for (; ii + 3 < max_ii_unpacked; ii += 4) + { + const float* ptr = (const float*)A + i * elempack + ii; + +#if __SSE2__ + __m128 _absmax = _mm_setzero_ps(); + __m128 _absmax1 = _mm_setzero_ps(); + __m128 _absmax2 = _mm_setzero_ps(); + __m128 _absmax3 = _mm_setzero_ps(); +#else + float absmax0 = 0.f; + float absmax1 = 0.f; + float absmax2 = 0.f; + float absmax3 = 0.f; +#endif + + int kk = 0; +#if __SSE2__ + for (; kk + 3 < K; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr + A_hstep); + __m128 _p2 = _mm_loadu_ps(ptr + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(ptr + A_hstep * 3); + _absmax = _mm_max_ps(_absmax, abs_ps(_p0)); + _absmax1 = _mm_max_ps(_absmax1, abs_ps(_p1)); + _absmax2 = _mm_max_ps(_absmax2, abs_ps(_p2)); + _absmax3 = _mm_max_ps(_absmax3, abs_ps(_p3)); + ptr += A_hstep * 4; + } + _absmax = _mm_max_ps(_absmax, _absmax2); + _absmax1 = _mm_max_ps(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + __m128 _p0 = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr + A_hstep); + _absmax = _mm_max_ps(_absmax, abs_ps(_p0)); + _absmax1 = _mm_max_ps(_absmax1, abs_ps(_p1)); + ptr += A_hstep * 2; + } + _absmax = _mm_max_ps(_absmax, _absmax1); +#endif // __SSE2__ + for (; kk < K; kk++) + { +#if __SSE2__ + __m128 _p = _mm_loadu_ps(ptr); + _absmax = _mm_max_ps(_absmax, abs_ps(_p)); +#else + absmax0 = std::max(absmax0, (float)fabsf(ptr[0])); + absmax1 = std::max(absmax1, (float)fabsf(ptr[1])); + absmax2 = std::max(absmax2, (float)fabsf(ptr[2])); + absmax3 = std::max(absmax3, (float)fabsf(ptr[3])); +#endif + ptr += A_hstep; + } + +#if __SSE2__ + if (elempack == 4) + { + float absmax = _mm_reduce_max_ps(_absmax); + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } +#endif // __SSE2__ + if (elempack == 1) + { +#if __SSE2__ + __m128 _scale = _mm_div_ps(_v127, _absmax); + __m128 _out_descale = _mm_div_ps(_absmax, _v127_B_scale); + _mm_store_ps(ps, _scale); + _mm_store_ps(pods, _out_descale); +#else + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + ps[2] = 127.f / absmax2; + ps[3] = 127.f / absmax3; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + pods[2] = absmax2 / v127_B_scale; + pods[3] = absmax3 / v127_B_scale; +#endif + ps += 4; + pods += 4; + } + } + for (; ii + 1 < max_ii_unpacked; ii += 2) + { + const float* ptr = (const float*)A + i * elempack + ii; + + float absmax0 = 0.f; + float absmax1 = 0.f; + + int kk = 0; +#if __AVX512F__ + __m512i 
_vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _absmax0_avx512 = _mm512_setzero_ps(); + __m512 _absmax1_avx512 = _mm512_setzero_ps(); + for (; kk + 15 < K; kk += 16) + { + __m512 _p0 = _mm512_i32gather_ps(_vindex, ptr, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, ptr + 1, sizeof(float)); + _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0)); + _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); + ptr += A_hstep * 16; + } + absmax0 = _mm512_comp_reduce_max_ps(_absmax0_avx512); + absmax1 = _mm512_comp_reduce_max_ps(_absmax1_avx512); +#endif // __AVX512F__ + for (; kk < K; kk++) + { + absmax0 = std::max(absmax0, (float)fabsf(ptr[0])); + absmax1 = std::max(absmax1, (float)fabsf(ptr[1])); + ptr += A_hstep; + } + + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } + for (; ii < max_ii_unpacked; ii++) + { + const float* ptr = (const float*)A + i * elempack + ii; + + float absmax = 0.f; + + int kk = 0; +#if __AVX512F__ + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _absmax_avx512 = _mm512_setzero_ps(); + for (; kk + 15 < K; kk += 16) + { + __m512 _p = _mm512_i32gather_ps(_vindex, ptr, sizeof(float)); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p)); + ptr += A_hstep * 16; + } + absmax = _mm512_comp_reduce_max_ps(_absmax_avx512); +#endif // __AVX512F__ + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabsf(ptr[0])); + ptr += A_hstep; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } +} + +static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_A_tile_fp32_to_int8_avx512vnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + transpose_pack_A_tile_fp32_to_int8_avxvnniint8(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_A_tile_fp32_to_int8_avxvnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_A_tile_fp32_to_int8_avx2(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = (signed char*)AT; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + __m512 _scales = _mm512_load_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_si512(); + __m512i _v127 = _mm512_set1_epi8(127); +#endif + + if (elempack == 16) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + __m512 _p8 = _mm512_load_ps(p0 + 128); + __m512 _p9 = _mm512_load_ps(p0 + 128 + 16); + __m512 _pa = _mm512_load_ps(p0 + 128 + 32); + __m512 _pb = _mm512_load_ps(p0 + 128 + 48); + __m512 _pc = _mm512_load_ps(p0 + 128 + 64); + __m512 _pd = _mm512_load_ps(p0 + 128 + 80); + __m512 _pe = _mm512_load_ps(p0 + 128 + 96); + __m512 _pf = _mm512_load_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + _p8 = _mm512_mul_ps(_p8, _mm512_set1_ps(scales[i + ii + 8])); + _p9 = _mm512_mul_ps(_p9, _mm512_set1_ps(scales[i + ii + 9])); + _pa = _mm512_mul_ps(_pa, _mm512_set1_ps(scales[i + ii + 10])); + _pb = _mm512_mul_ps(_pb, _mm512_set1_ps(scales[i + ii + 11])); + _pc = _mm512_mul_ps(_pc, _mm512_set1_ps(scales[i + ii + 12])); + _pd = _mm512_mul_ps(_pd, _mm512_set1_ps(scales[i + ii + 13])); + _pe = _mm512_mul_ps(_pe, _mm512_set1_ps(scales[i + ii + 14])); + _pf = _mm512_mul_ps(_pf, _mm512_set1_ps(scales[i + ii + 15])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi32(_t2, _t3); + _t0 = _mm512_unpacklo_epi64(_t4, _t6); + _t1 = _mm512_unpackhi_epi64(_t4, _t6); + _t2 = _mm512_unpacklo_epi64(_t5, _t7); + _t3 = _mm512_unpackhi_epi64(_t5, 
_t7); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t0); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t1); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t2); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t3); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + _mm512_store_si512((__m512i*)(pp + 128), _t2); + _mm512_store_si512((__m512i*)(pp + 192), _t3); + + pp += 256; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + _mm512_store_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + __m512 _p8 = _mm512_load_ps(p0 + 128); + __m512 _p9 = _mm512_load_ps(p0 + 128 + 16); + __m512 _pa = _mm512_load_ps(p0 + 128 + 32); + __m512 _pb = _mm512_load_ps(p0 + 128 + 48); + __m512 _pc = _mm512_load_ps(p0 + 128 + 64); + __m512 _pd = _mm512_load_ps(p0 + 128 + 80); + __m512 _pe = _mm512_load_ps(p0 + 128 + 96); + __m512 _pf = _mm512_load_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + _p8 = _mm512_mul_ps(_p8, _mm512_set1_ps(scales[i + ii + 8])); + _p9 = _mm512_mul_ps(_p9, _mm512_set1_ps(scales[i + ii + 9])); + _pa = _mm512_mul_ps(_pa, _mm512_set1_ps(scales[i + ii + 10])); + _pb = _mm512_mul_ps(_pb, _mm512_set1_ps(scales[i + ii + 11])); + _pc = _mm512_mul_ps(_pc, _mm512_set1_ps(scales[i + ii + 12])); + _pd = _mm512_mul_ps(_pd, _mm512_set1_ps(scales[i + ii + 13])); + _pe = _mm512_mul_ps(_pe, _mm512_set1_ps(scales[i + ii + 14])); + _pf = _mm512_mul_ps(_pf, _mm512_set1_ps(scales[i + ii + 15])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi16(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi16(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi16(_t2, _t3); + + _t0 = _mm512_unpacklo_epi32(_t4, _t6); + _t1 = _mm512_unpackhi_epi32(_t4, _t6); + _t2 = _mm512_unpacklo_epi32(_t5, _t7); + _t3 = 
_mm512_unpackhi_epi32(_t5, _t7); + + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_permutex_epi64(_t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_permutex_epi64(_t3, _MM_SHUFFLE(3, 1, 2, 0)); + _t0 = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_shuffle_i32x4(_t2, _t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_shuffle_i32x4(_t3, _t3, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + _mm512_store_si512((__m512i*)(pp + 128), _t2); + _mm512_store_si512((__m512i*)(pp + 192), _t3); + + pp += 256; + p0 += A_hstep * 16; + } +#endif // __AVX512VNNI__ + } + if (elempack == 8) + { + __m512 _scales0 = _scales; + __m512 _scales1 = _scales; + __m512 _scales2 = _scales; + __m512 _scales3 = _scales; + __m512 _scales4 = _scales; + __m512 _scales5 = _scales; + __m512 _scales6 = _scales; + __m512 _scales7 = _scales; + transpose16x8_ps(_scales0, _scales1, _scales2, _scales3, _scales4, _scales5, _scales6, _scales7); + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + _p4 = _mm512_mul_ps(_p4, _scales4); + _p5 = _mm512_mul_ps(_p5, _scales5); + _p6 = _mm512_mul_ps(_p6, _scales6); + _p7 = _mm512_mul_ps(_p7, _scales7); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _ppa = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _ppb = _mm512_unpackhi_epi32(_t2, _t3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _ppa); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _ppb); + + _mm512_store_si512((__m512i*)pp, _ppa); + _mm512_store_si512((__m512i*)(pp + 64), _ppb); + + pp += 128; + p0 += A_hstep * 8; + } + if (max_kk >= 4) + { + _mm512_store_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#else // __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + _p4 = _mm512_mul_ps(_p4, _scales4); + _p5 = _mm512_mul_ps(_p5, _scales5); + _p6 = _mm512_mul_ps(_p6, _scales6); + _p7 = _mm512_mul_ps(_p7, _scales7); + + __m128i _pp0 
= float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi16(_t0, _t1); + _t0 = _mm512_unpacklo_epi16(_t2, _t3); + _t1 = _mm512_unpackhi_epi16(_t2, _t3); + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppa = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppb = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_store_si512((__m512i*)pp, _ppa); + _mm512_store_si512((__m512i*)(pp + 64), _ppb); + + pp += 128; + p0 += A_hstep * 8; + } +#endif // __AVX512VNNI__ + } + if (elempack == 4) + { + __m512 _scales0 = _scales; + __m512 _scales1 = _scales; + __m512 _scales2 = _scales; + __m512 _scales3 = _scales; + transpose16x4_ps(_scales0, _scales1, _scales2, _scales3); + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_store_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + _mm512_store_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#else // __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m256i _pp02 = combine4x2_epi32(_pp0, _pp2); + __m256i _pp13 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t0 = _mm256_unpacklo_epi16(_pp02, _pp13); + __m256i _t1 = _mm256_unpackhi_epi16(_pp02, _pp13); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi16(_t2, _t3); + _t1 = _mm256_unpackhi_epi16(_t2, _t3); + + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)(pp + 32), _t1); + + pp += 64; + p0 += A_hstep * 4; + } +#endif // __AVX512VNNI__ + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 2); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 3); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = 
_mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += A_hstep; + } + } + } +#endif // __AVX512F__ +#if !__AVX2__ + signed char* pp1 = pp + max_kk * 4; +#endif + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + __m256 _scales = _mm256_load_ps((const float*)scales + i + ii); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m256i _w_shift = _mm256_setzero_si256(); + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; +#if __AVX512VNNI__ + __m512i _w_shift_avx512 = _mm512_setzero_si512(); + __m512i _v127_avx512 = _mm512_set1_epi8(127); + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose4x8_epi32(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _w_shift_avx512 = _mm512_dpbusd_epi32(_w_shift_avx512, _v127_avx512, _t0); + _w_shift_avx512 = _mm512_dpbusd_epi32(_w_shift_avx512, _v127_avx512, _t1); + + 
_mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + + pp += 128; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + _w_shift = _mm256_add_epi32(_mm512_extracti32x8_epi32(_w_shift_avx512, 0), _mm512_extracti32x8_epi32(_w_shift_avx512, 1)); + _mm256_store_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose8x8_epi16(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + + pp += 128; + p0 += A_hstep * 16; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + __m256 _p4 = _mm256_load_ps(p0 + 32); + __m256 _p5 = _mm256_load_ps(p0 + 40); + __m256 _p6 = _mm256_load_ps(p0 + 48); + __m256 _p7 = _mm256_load_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _mm256_set1_ps(scales[i + ii])); + _p1 = _mm256_mul_ps(_p1, _mm256_set1_ps(scales[i + ii + 1])); + _p2 = _mm256_mul_ps(_p2, _mm256_set1_ps(scales[i + ii + 2])); + _p3 = _mm256_mul_ps(_p3, _mm256_set1_ps(scales[i + ii + 3])); + _p4 = _mm256_mul_ps(_p4, _mm256_set1_ps(scales[i + ii + 4])); + _p5 = _mm256_mul_ps(_p5, _mm256_set1_ps(scales[i + ii + 5])); + _p6 = _mm256_mul_ps(_p6, _mm256_set1_ps(scales[i + ii + 6])); + _p7 = _mm256_mul_ps(_p7, _mm256_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); + + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t2 = _mm256_unpacklo_epi32(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi32(_t0, _t1); + _t0 = _mm256_unpacklo_epi64(_t2, _t3); + _t1 = _mm256_unpackhi_epi64(_t2, _t3); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _t0); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _t1); +#endif // !__AVXVNNIINT8__ + _mm256_store_si256((__m256i*)pp, _t0); + 
_mm256_store_si256((__m256i*)(pp + 32), _t1); + + pp += 64; + p0 += A_hstep * 8; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_store_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + __m256 _p4 = _mm256_load_ps(p0 + 32); + __m256 _p5 = _mm256_load_ps(p0 + 40); + __m256 _p6 = _mm256_load_ps(p0 + 48); + __m256 _p7 = _mm256_load_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _mm256_set1_ps(scales[i + ii])); + _p1 = _mm256_mul_ps(_p1, _mm256_set1_ps(scales[i + ii + 1])); + _p2 = _mm256_mul_ps(_p2, _mm256_set1_ps(scales[i + ii + 2])); + _p3 = _mm256_mul_ps(_p3, _mm256_set1_ps(scales[i + ii + 3])); + _p4 = _mm256_mul_ps(_p4, _mm256_set1_ps(scales[i + ii + 4])); + _p5 = _mm256_mul_ps(_p5, _mm256_set1_ps(scales[i + ii + 5])); + _p6 = _mm256_mul_ps(_p6, _mm256_set1_ps(scales[i + ii + 6])); + _p7 = _mm256_mul_ps(_p7, _mm256_set1_ps(scales[i + ii + 7])); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); + +#if __AVX2__ + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi32(_t2, _t3); + _t1 = _mm256_unpackhi_epi32(_t2, _t3); + _t0 = _mm256_permute4x64_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm256_permute4x64_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)(pp + 32), _t1); + pp += 64; +#else + __m128i _tt0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _tt2 = _mm_unpacklo_epi16(_pp2, _pp3); + __m128i _tt3 = _mm_unpackhi_epi16(_pp2, _pp3); + _pp0 = _mm_unpacklo_epi32(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi32(_tt0, _tt1); + _pp2 = _mm_unpacklo_epi32(_tt2, _tt3); + _pp3 = _mm_unpackhi_epi32(_tt2, _tt3); + __m256i _t0 = combine4x2_epi32(_pp0, _pp1); + __m256i _t1 = combine4x2_epi32(_pp2, _pp3); + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)pp1, _t1); + pp += 32; + pp1 += 32; +#endif + p0 += A_hstep * 8; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } + if (elempack == 4) + { + __m256 _scales0 = _scales; + __m256 _scales1 = _scales; + __m256 _scales2 = _scales; + __m256 _scales3 = _scales; + transpose8x4_ps(_scales0, _scales1, _scales2, _scales3); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm256_store_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += A_hstep * 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_store_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || 
__AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + +#if __AVX2__ + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi16(_t0, _t1); + __m128i _t3 = _mm_unpackhi_epi16(_t0, _t1); + _t0 = _mm_unpacklo_epi16(_t2, _t3); + _t1 = _mm_unpackhi_epi16(_t2, _t3); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + pp += 32; +#else + __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); + __m128i _t0 = _mm_shuffle_epi8(_pp0, _si); + __m128i _t1 = _mm_shuffle_epi8(_pp1, _si); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)pp1, _t1); + pp += 16; + pp1 += 16; +#endif + p0 += A_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep); + __m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 2); + __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 3); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + _p2 = _mm256_mul_ps(_p2, _scales); + _p3 = _mm256_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += A_hstep * 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm256_storeu_si256((__m256i*)pp, _w_shift); + pp += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep); + + _p0 = _mm256_mul_ps(_p0, _scales); + _p1 = _mm256_mul_ps(_p1, _scales); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX2__ +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; +#else + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); + pp += 8; + pp1 += 8; +#endif + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m256 _p = _mm256_loadu_ps(p0); + + _p = _mm256_mul_ps(_p, _scales); + + int64_t v = float2int8_avx(_p); + +#if __AVX2__ + *(int64_t*)pp = v; + pp += 8; +#else + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); + pp += 4; + pp1 += 4; +#endif + p0 += A_hstep; + } + } + +#if !__AVX2__ + pp = pp1; + pp1 = pp + max_kk * 4; +#endif + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + +#if __AVX512VNNI__ || (__AVXVNNI__ && 
!__AVXVNNIINT8__) + __m128i _w_shift = _mm_setzero_si128(); + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scales0 = _mm512_set1_ps(scales[i + ii]); + __m512 _scales1 = _mm512_set1_ps(scales[i + ii + 1]); + __m512 _scales2 = _mm512_set1_ps(scales[i + ii + 2]); + __m512 _scales3 = _mm512_set1_ps(scales[i + ii + 3]); + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp0); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp1); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp2); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp3); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + _mm_store_si128((__m128i*)(pp + 32), _pp2); + _mm_store_si128((__m128i*)(pp + 48), _pp3); + + pp += 64; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + _mm_store_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose8x4_epi16(_pp0, _pp1, _pp2, _pp3); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + _mm_store_si128((__m128i*)(pp + 32), _pp2); + _mm_store_si128((__m128i*)(pp + 48), _pp3); + + pp += 64; + p0 += A_hstep * 16; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scales0 = _mm256_set1_ps(scales[i + ii]); + __m256 _scales1 = _mm256_set1_ps(scales[i + ii + 1]); + __m256 _scales2 = _mm256_set1_ps(scales[i + ii + 2]); + __m256 _scales3 = _mm256_set1_ps(scales[i + ii + 3]); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _t0 = _mm_unpacklo_epi32(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi32(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi64(_t0, _t1); + _pp1 = _mm_unpackhi_epi64(_t0, _t1); +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp0); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp1); +#endif // !__AVXVNNIINT8__ + _mm_store_si128((__m128i*)pp, _pp0); + 
_mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += A_hstep * 8; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm_store_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + _p2 = _mm256_mul_ps(_p2, _scales2); + _p3 = _mm256_mul_ps(_p3, _scales3); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi32(_t0, _t1); + _pp1 = _mm_unpackhi_epi32(_t0, _t1); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += A_hstep * 8; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scales0 = _mm_set1_ps(scales[i + ii]); + __m128 _scales1 = _mm_set1_ps(scales[i + ii + 1]); + __m128 _scales2 = _mm_set1_ps(scales[i + ii + 2]); + __m128 _scales3 = _mm_set1_ps(scales[i + ii + 3]); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scales0); + _p1 = _mm_mul_ps(_p1, _scales1); + _p2 = _mm_mul_ps(_p2, _scales2); + _p3 = _mm_mul_ps(_p3, _scales3); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += A_hstep * 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm_store_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scales0); + _p1 = _mm_mul_ps(_p1, _scales1); + _p2 = _mm_mul_ps(_p2, _scales2); + _p3 = _mm_mul_ps(_p3, _scales3); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + _pp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pp, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _pp = _mm_shuffle_epi32(_pp, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += A_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } + if (elempack == 1) + { + __m128 _scales = _mm_load_ps((const float*)scales + i + ii); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + _p2 = _mm_mul_ps(_p2, _scales); + _p3 = _mm_mul_ps(_p3, _scales); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp 
+= 16; + p0 += A_hstep * 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + _mm_storeu_si128((__m128i*)pp, _w_shift); + pp += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + _p0 = _mm_mul_ps(_p0, _scales); + _p1 = _mm_mul_ps(_p1, _scales); + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128 _p = _mm_loadu_ps(p0); + _p = _mm_mul_ps(_p, _scales); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scales0 = _mm512_set1_ps(scales[i + ii]); + __m512 _scales1 = _mm512_set1_ps(scales[i + ii + 1]); + + int kk = 0; +#if __AVX512VNNI__ + __m128i _w_shift = _mm_setzero_si128(); +#endif // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + +#if __AVX512VNNI__ + __m128i _t0 = _mm_unpacklo_epi32(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi32(_pp0, _pp1); + + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _t0); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _t1); +#else // __AVX512VNNI__ + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); +#endif // __AVX512VNNI__ + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += A_hstep * 16; + } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + _w_shift = _mm_shuffle_epi32(_w_shift, _MM_SHUFFLE(3, 1, 2, 0)); + _w_shift = _mm_hadd_epi32(_w_shift, _w_shift); + _mm_storel_epi64((__m128i*)pp, _w_shift); + pp += 8; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scales0 = _mm256_set1_ps(scales[i + ii]); + __m256 _scales1 = _mm256_set1_ps(scales[i + ii + 1]); + + int kk = 0; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _w_shift = _mm_setzero_si128(); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scales0); + _p1 = _mm256_mul_ps(_p1, _scales1); + + __m128i _pp = float2int8_avx(_p0, _p1); + + _pp = _mm_shuffle_epi32(_pp, _MM_SHUFFLE(3, 1, 2, 0)); +#if __AVX512VNNI__ || __AVXVNNI__ +#if !__AVXVNNIINT8__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + _pp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pp, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += A_hstep * 8; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + _w_shift = _mm_shuffle_epi32(_w_shift, _MM_SHUFFLE(3, 1, 2, 0)); 
+ _w_shift = _mm_hadd_epi32(_w_shift, _w_shift); + _mm_storel_epi64((__m128i*)pp, _w_shift); + pp += 8; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scales0 = _mm_set1_ps(scales[i + ii]); + __m128 _scales1 = _mm_set1_ps(scales[i + ii + 1]); + + int kk = 0; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + _p0 = _mm_mul_ps(_p0, _scales0); + _p1 = _mm_mul_ps(_p1, _scales1); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += A_hstep * 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +#endif // __SSE2__ + if (elempack == 1) + { + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; + + int kk = 0; +#if __SSE2__ + __m128 _scales0 = _mm_set1_ps(scale0); + __m128 _scales1 = _mm_set1_ps(scale1); + __m128 _scales0011 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_scales0), _mm_castps_pd(_scales1))); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep))); + __m128 _p2 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 2))); + __m128 _p3 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 3))); + __m128 _p01 = _mm_unpacklo_ps(_p0, _p1); + __m128 _p23 = _mm_unpacklo_ps(_p2, _p3); + _p0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23))); + _p1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23))); + _p0 = _mm_mul_ps(_p0, _scales0); + _p1 = _mm_mul_ps(_p1, _scales1); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += A_hstep * 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = 
w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + A_hstep))); + __m128 _p = _mm_unpacklo_ps(_p0, _p1); + _p = _mm_mul_ps(_p, _scales0011); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += A_hstep * 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + +#if __AVX512VNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ + + const float scale = scales[i + ii]; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scales[i + ii]); + + int kk = 0; +#if __AVX512VNNI__ + __m128i _w_shift = _mm_setzero_si128(); +#endif // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p = _mm512_load_ps(p0); + _p = _mm512_mul_ps(_p, _scale); + __m128i _pp = float2int8_avx512(_p); +#if __AVX512VNNI__ + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); +#endif // __AVX512VNNI__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += A_hstep * 16; + } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = _mm_reduce_add_epi32(_w_shift); + pp += 4; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scales[i + ii]); + + int kk = 0; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p = _mm256_load_ps(p0); + _p = _mm256_mul_ps(_p, _scale); + int64_t v = float2int8_avx(_p); + *(int64_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + w_shift += pp[4]; + w_shift += pp[5]; + w_shift += pp[6]; + w_shift += pp[7]; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 8; + p0 += A_hstep * 8; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scale = _mm_set1_ps(scales[i + ii]); + + int kk = 0; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p = _mm_load_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 4; + p0 += A_hstep * 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +#endif // __SSE2__ + if (elempack == 1) + { + int kk = 0; +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scales[i + ii]); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + int w_shift = 0; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && 
!__AVXVNNIINT8__) + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep)); + + __m128 _p = _mm_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m128 _p = _mm_setr_ps(p0[0], p0[A_hstep], p0[A_hstep * 2], p0[A_hstep * 3]); +#endif + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 4; + p0 += A_hstep * 4; + } +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void compute_B_fp32_int8_scale(const Mat& B, float& scale) +{ + // NCNN_LOGE("compute_B_fp32_int8_scale"); + + float absmax = 0.f; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _absmax_avx512 = _mm512_setzero_ps(); +#endif // __AVX512F__ + __m256 _absmax_avx = _mm256_setzero_ps(); +#endif // __AVX__ + __m128 _absmax = _mm_setzero_ps(); +#endif // __SSE2__ + for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++) + { + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + const float* ptr = (const float*)B + i * B_hstep * B.elempack; + + const int size = B.w * B.elempack; + + int j = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; j + 15 < size; j += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p)); + ptr += 16; + } +#endif // __AVX512F__ + for (; j + 7 < size; j += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p)); + ptr += 8; + } +#endif // __AVX__ + for (; j + 3 < size; j += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + _absmax = _mm_max_ps(_absmax, abs_ps(_p)); + ptr += 4; + } +#endif // __SSE2__ + for (; j < size; j++) + { + absmax = std::max(absmax, (float)fabsf(ptr[0])); + ptr++; + } + } +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + absmax = std::max(absmax, _mm512_comp_reduce_max_ps(_absmax_avx512)); +#endif // __AVX512F__ + absmax = std::max(absmax, _mm256_reduce_max_ps(_absmax_avx)); +#endif // __AVX__ + absmax = std::max(absmax, _mm_reduce_max_ps(_absmax)); +#endif + + scale = absmax == 0.f ? 
1.f : 127.f / absmax; +} + +static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_B_tile_fp32_to_int8_avx512vnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + pack_B_tile_fp32_to_int8_avxvnniint8(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_B_tile_fp32_to_int8_avxvnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_fp32_to_int8 %d %d %d", max_jj, max_kk, elempack); + + signed char* pp = BT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + + if (elempack == 16) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += 32; + } + for (; kk < max_kk; kk++) + { + __m512 _p = _mm512_load_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 16; + } + } + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 8 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, 
_scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi8(_pp2, _pp3); + __m128i _t3 = _mm_unpackhi_epi8(_pp2, _pp3); + _pp0 = _mm_unpacklo_epi8(_t0, _t1); + _pp1 = _mm_unpackhi_epi8(_t0, _t1); + _pp2 = _mm_unpacklo_epi8(_t2, _t3); + _pp3 = _mm_unpackhi_epi8(_t2, _t3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 32; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep * 8); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp0 = _mm_shuffle_epi8(_pp0, _si); + _pp1 = _mm_shuffle_epi8(_pp1, _si); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 8); + + __m512 _p = combine8x2_ps(_p0, _p1); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 8; + } + } + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep * 4); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 12); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm512_shuffle_epi8(_pp, _mm512_broadcast_i32x4(_si)); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 16; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 4); + __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 8); + __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 12); + + __m512 _p01 = combine8x2_ps(_p0, _p1); + __m512 _p23 = combine8x2_ps(_p2, _p3); + + _p01 = _mm512_mul_ps(_p01, _scale); + _p23 = _mm512_mul_ps(_p23, _scale); + + __m128i _pp0 = float2int8_avx512(_p01); + __m128i _pp1 = float2int8_avx512(_p23); + + __m128i _si = _mm_setr_epi8(0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15); + _pp0 = _mm_shuffle_epi8(_pp0, _si); + _pp1 = _mm_shuffle_epi8(_pp1, _si); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 8); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 12); + + __m512 _p = 
combine4x4_ps(_p0, _p1, _p2, _p3); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + B_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + B_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + B_hstep * 7); + __m128 _p8 = _mm_loadu_ps(p0 + B_hstep * 8); + __m128 _p9 = _mm_loadu_ps(p0 + B_hstep * 9); + __m128 _pa = _mm_loadu_ps(p0 + B_hstep * 10); + __m128 _pb = _mm_loadu_ps(p0 + B_hstep * 11); + __m128 _pc = _mm_loadu_ps(p0 + B_hstep * 12); + __m128 _pd = _mm_loadu_ps(p0 + B_hstep * 13); + __m128 _pe = _mm_loadu_ps(p0 + B_hstep * 14); + __m128 _pf = _mm_loadu_ps(p0 + B_hstep * 15); + + __m512 _t0 = combine4x4_ps(_p0, _p1, _p2, _p3); + __m512 _t1 = combine4x4_ps(_p4, _p5, _p6, _p7); + __m512 _t2 = combine4x4_ps(_p8, _p9, _pa, _pb); + __m512 _t3 = combine4x4_ps(_pc, _pd, _pe, _pf); + + _t0 = _mm512_mul_ps(_t0, _scale); + _t1 = _mm512_mul_ps(_t1, _scale); + _t2 = _mm512_mul_ps(_t2, _scale); + _t3 = _mm512_mul_ps(_t3, _scale); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 4; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B_hstep)); + + __m512 _p0 = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, p0 + 1, sizeof(float)); + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B_hstep)); + + __m512 _p = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0++; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + +#if __AVX__ + __m256 _scale = _mm256_set1_ps(scale); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) +#else + __m128 _scale = _mm_set1_ps(scale); +#endif // __AVX__ + +#if __AVX__ + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = 
_mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _pp = _mm256_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 32; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + __m256 _p = _mm256_load_ps(p0); + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); + + *(int64_t*)pp = v; + pp += 8; + p0 += 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 4); + __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 4 + 8); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _pp = _mm256_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm256_shuffle_epi8(_pp, combine4x2_epi32(_si, _si)); + + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 16; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep * 4); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#else // __AVX__ + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + B_hstep * 4); + __m128 _p3 = _mm_load_ps(p0 + B_hstep * 4 + 4); + + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + __m128 _t2 = _mm_unpacklo_ps(_p2, _p3); + __m128 _t3 = _mm_unpackhi_ps(_p2, _p3); + + _t0 = _mm_mul_ps(_t0, _scale); + _t1 = _mm_mul_ps(_t1, _scale); + _t2 = _mm_mul_ps(_t2, _scale); + _t3 = _mm_mul_ps(_t3, _scale); + + __m128i _pp = float2int8_sse(_t0, _t1, _t2, _t3); +#endif // __AVX__ + +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + B_hstep * 4); + +#if __AVX__ + __m256 _p = combine4x2_ps(_p0, _p1); + _p = _mm256_mul_ps(_p, _scale); + 
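+                // the 8 scaled floats (two packed groups of 4) are converted to 8 int8 values
+                // packed in an int64_t and stored below as a single 8-byte write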
+ int64_t v = float2int8_avx(_p); +#else // __AVX__ + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + + int64_t v = float2int8_sse(_p0, _p1); +#endif // __AVX__ + + *(int64_t*)pp = v; + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + B_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + B_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + B_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + B_hstep * 7); + + __m256 _t0 = combine4x2_ps(_p0, _p1); + __m256 _t1 = combine4x2_ps(_p2, _p3); + __m256 _t2 = combine4x2_ps(_p4, _p5); + __m256 _t3 = combine4x2_ps(_p6, _p7); + + _t0 = _mm256_mul_ps(_t0, _scale); + _t1 = _mm256_mul_ps(_t1, _scale); + _t2 = _mm256_mul_ps(_t2, _scale); + _t3 = _mm256_mul_ps(_t3, _scale); + + __m128i _pp0 = float2int8_avx(_t0, _t1); + __m128i _pp1 = float2int8_avx(_t2, _t3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _pp = _mm256_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B_hstep)); + + __m256 _p0 = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); + __m256 _p1 = _mm256_i32gather_ps(p0 + 1, _vindex, sizeof(float)); +#else + __m256 _p0 = _mm256_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1], p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]); + __m256 _p1 = _mm256_setr_ps(p0[B_hstep * 4], p0[B_hstep * 4 + 1], p0[B_hstep * 5], p0[B_hstep * 5 + 1], p0[B_hstep * 6], p0[B_hstep * 6 + 1], p0[B_hstep * 7], p0[B_hstep * 7 + 1]); +#endif + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + +#if __AVX2__ + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#endif +#else // __AVX__ + __m128 _p0 = _mm_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1]); + __m128 _p1 = _mm_setr_ps(p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]); + __m128 _p2 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 4 + 1], p0[B_hstep * 5], p0[B_hstep * 5 + 1]); + __m128 _p3 = _mm_setr_ps(p0[B_hstep * 6], p0[B_hstep * 6 + 1], p0[B_hstep * 7], p0[B_hstep * 7 + 1]); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); +#endif // __AVX__ + +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { +#if __AVX__ +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B_hstep)); + + __m256 _p = _mm256_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m256 _p = _mm256_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3], p0[B_hstep * 4], p0[B_hstep * 5], p0[B_hstep * 6], p0[B_hstep * 7]); +#endif + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); +#else // __AVX__ + 
__m128 _p0 = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]); + __m128 _p1 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 5], p0[B_hstep * 6], p0[B_hstep * 7]); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + + int64_t v = float2int8_sse(_p0, _p1); +#endif // __AVX__ + + *(int64_t*)pp = v; + pp += 8; + p0++; + } + } + } +#else // defined(__x86_64__) || defined(_M_X64) +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ + + for (; jj + 15 < max_jj; jj += 16) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + signed char* pp1 = pp + max_kk * 4; + signed char* pp2 = pp + max_kk * 8; + signed char* pp3 = pp + max_kk * 12; + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); + _pp2 = _mm_add_epi8(_pp2, _v127); + _pp3 = _mm_add_epi8(_pp3, _v127); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + _mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)pp1, _pp1); + _mm_storeu_si128((__m128i*)pp2, _pp2); + _mm_storeu_si128((__m128i*)pp3, _pp3); + + pp += 16; + pp1 += 16; + pp2 += 16; + pp3 += 16; + p0 += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_t0)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_t0)); + _mm_storel_pd((double*)pp2, _mm_castsi128_pd(_t1)); + _mm_storeh_pd((double*)pp3, _mm_castsi128_pd(_t1)); + + pp += 8; + pp1 += 8; + pp2 += 8; + pp3 += 8; + p0 += 32; + } + for (; kk < max_kk; kk++) + { + __m512 _p = _mm512_load_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _v = float2int8_avx512(_p); + + *(int*)pp = _mm_extract_epi32(_v, 0); + *(int*)pp1 = _mm_extract_epi32(_v, 1); + *(int*)pp2 = _mm_extract_epi32(_v, 2); + *(int*)pp3 = _mm_extract_epi32(_v, 3); + + pp += 4; + pp1 += 4; + pp2 += 4; + pp3 += 4; + p0 += 16; + } + + pp = pp3; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scale); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + signed char* pp1 = pp + max_kk * 4; + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = 
_mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); +#if !__AVXVNNIINT8__ + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); +#endif // !__AVXVNNIINT8__ + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + _mm_storeu_si128((__m128i*)pp, _pp0); + _mm_storeu_si128((__m128i*)pp1, _pp1); + + pp += 16; + pp1 += 16; + p0 += 32; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + + _mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp)); + _mm_storeh_pd((double*)pp1, _mm_castsi128_pd(_pp)); + pp += 8; + pp1 += 8; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + __m256 _p = _mm256_load_ps(p0); + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); + + *(int32_t*)pp = (int32_t)v; + *(int32_t*)pp1 = (int32_t)(v >> 32); + + pp += 4; + pp1 += 4; + p0 += 8; + } + + pp = pp1; + } + } +#endif // __AVX__ +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + __m128 _scale = _mm_set1_ps(scale); +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#if !__AVXVNNIINT8__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 16; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p = _mm_load_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + 
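+            // quantize the 4x4 block to int8; the u8*s8 VNNI dot product (vpdpbusd) needs an
+            // unsigned operand, so the packed values are shifted by +127 into unsigned range,
+            // except on AVX-VNNI-INT8 where vpdpbssd multiplies signed*signed directly;
+            // the +127 offset has to be compensated outside this packing step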
__m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); +#if !__AVXVNNIINT8__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep)); + + __m128 _t0 = _mm_i32gather_ps(p0, _vindex, sizeof(float)); + __m128 _t1 = _mm_i32gather_ps(p0 + 1, _vindex, sizeof(float)); + __m128 _p0 = _mm_unpacklo_ps(_t0, _t1); + __m128 _p1 = _mm_unpackhi_ps(_t0, _t1); +#else + __m128 _p0 = _mm_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1]); + __m128 _p1 = _mm_setr_ps(p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]); +#endif + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep)); + + __m128 _p = _mm_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m128 _p = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]); +#endif + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0++; + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); +#endif // __SSE2__ + + // if (elempack == 1) + { + int kk = 0; +#if __SSE2__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep))); + __m128 _p = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep] * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); +#endif // __SSE2__ + + // if (elempack == 1) + { + int kk = 0; +#if __SSE2__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p = _mm_loadu_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && 
!__AVXVNNIINT8__) + pp += 4; + p0 += 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_B_tile_fp32_to_int8_avx512vnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + transpose_pack_B_tile_fp32_to_int8_avxvnniint8(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_B_tile_fp32_to_int8_avxvnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + transpose_pack_B_tile_fp32_to_int8_avx2(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_fp32_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + + __m512 _scale = _mm512_set1_ps(scale); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + + if (elempack == 16) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + __m512 _p8 = _mm512_load_ps(p0 + 128); + __m512 _p9 = _mm512_load_ps(p0 + 128 + 16); + __m512 _pa = _mm512_load_ps(p0 + 128 + 32); + __m512 _pb = _mm512_load_ps(p0 + 128 + 48); + __m512 _pc = _mm512_load_ps(p0 + 128 + 64); + __m512 _pd = _mm512_load_ps(p0 + 128 + 80); + __m512 _pe = _mm512_load_ps(p0 + 128 + 96); + __m512 _pf = _mm512_load_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + _p8 = _mm512_mul_ps(_p8, _scale); + _p9 = _mm512_mul_ps(_p9, _scale); + _pa = _mm512_mul_ps(_pa, _scale); + _pb = _mm512_mul_ps(_pb, _scale); + _pc = _mm512_mul_ps(_pc, _scale); + _pd = _mm512_mul_ps(_pd, _scale); + _pe = _mm512_mul_ps(_pe, _scale); + _pf = _mm512_mul_ps(_pf, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + 
__m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi32(_t2, _t3); + _t0 = _mm512_unpacklo_epi64(_t4, _t6); + _t1 = _mm512_unpackhi_epi64(_t4, _t6); + _t2 = _mm512_unpacklo_epi64(_t5, _t7); + _t3 = _mm512_unpackhi_epi64(_t5, _t7); + + _t0 = _mm512_add_epi8(_t0, _v127); + _t1 = _mm512_add_epi8(_t1, _v127); + _t2 = _mm512_add_epi8(_t2, _v127); + _t3 = _mm512_add_epi8(_t3, _v127); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + _mm512_store_si512((__m512i*)(pp + 128), _t2); + _mm512_store_si512((__m512i*)(pp + 192), _t3); + + pp += 256; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + __m512 _p8 = _mm512_load_ps(p0 + 128); + __m512 _p9 = _mm512_load_ps(p0 + 128 + 16); + __m512 _pa = _mm512_load_ps(p0 + 128 + 32); + __m512 _pb = _mm512_load_ps(p0 + 128 + 48); + __m512 _pc = _mm512_load_ps(p0 + 128 + 64); + __m512 _pd = _mm512_load_ps(p0 + 128 + 80); + __m512 _pe = _mm512_load_ps(p0 + 128 + 96); + __m512 _pf = _mm512_load_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + _p8 = _mm512_mul_ps(_p8, _scale); + _p9 = _mm512_mul_ps(_p9, _scale); + _pa = _mm512_mul_ps(_pa, _scale); + _pb = _mm512_mul_ps(_pb, _scale); + _pc = _mm512_mul_ps(_pc, _scale); + _pd = _mm512_mul_ps(_pd, _scale); + _pe = _mm512_mul_ps(_pe, _scale); + _pf = _mm512_mul_ps(_pf, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t5 = 
_mm512_unpackhi_epi16(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi16(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi16(_t2, _t3); + + _t0 = _mm512_unpacklo_epi32(_t4, _t6); + _t1 = _mm512_unpackhi_epi32(_t4, _t6); + _t2 = _mm512_unpacklo_epi32(_t5, _t7); + _t3 = _mm512_unpackhi_epi32(_t5, _t7); + + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_permutex_epi64(_t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_permutex_epi64(_t3, _MM_SHUFFLE(3, 1, 2, 0)); + _t0 = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_shuffle_i32x4(_t2, _t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_shuffle_i32x4(_t3, _t3, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + _mm512_store_si512((__m512i*)(pp + 128), _t2); + _mm512_store_si512((__m512i*)(pp + 192), _t3); + + pp += 256; + p0 += B_hstep * 16; + } +#endif // __AVX512VNNI__ + } + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _ppa = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _ppb = _mm512_unpackhi_epi32(_t2, _t3); + + _ppa = _mm512_add_epi8(_ppa, _v127); + _ppb = _mm512_add_epi8(_ppb, _v127); + + _mm512_store_si512((__m512i*)pp, _ppa); + _mm512_store_si512((__m512i*)(pp + 64), _ppb); + + pp += 128; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = 
float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi16(_t0, _t1); + _t0 = _mm512_unpacklo_epi16(_t2, _t3); + _t1 = _mm512_unpackhi_epi16(_t2, _t3); + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppa = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppb = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_store_si512((__m512i*)pp, _ppa); + _mm512_store_si512((__m512i*)(pp + 64), _ppb); + + pp += 128; + p0 += B_hstep * 8; + } +#endif // __AVX512VNNI__ + } + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_store_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m256i _pp02 = combine4x2_epi32(_pp0, _pp2); + __m256i _pp13 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t0 = _mm256_unpacklo_epi16(_pp02, _pp13); + __m256i _t1 = _mm256_unpackhi_epi16(_pp02, _pp13); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi16(_t2, _t3); + _t1 = _mm256_unpackhi_epi16(_t2, _t3); + + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)(pp + 32), _t1); + + pp += 64; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep); + __m512 _p2 = _mm512_loadu_ps(p0 + B_hstep * 2); + __m512 _p3 = _mm512_loadu_ps(p0 + B_hstep * 3); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp = _mm512_add_epi8(_pp, _v127); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + 
{ + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + B_hstep); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ + __m512i _v127_avx512 = _mm512_set1_epi8(127); + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + transpose4x8_epi32(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _t0 = _mm512_add_epi8(_t0, _v127_avx512); + _t1 = _mm512_add_epi8(_t1, _v127_avx512); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + + pp += 128; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + __m512 _p4 = _mm512_load_ps(p0 + 64); + __m512 _p5 = _mm512_load_ps(p0 + 80); + __m512 _p6 = _mm512_load_ps(p0 + 96); + __m512 _p7 = _mm512_load_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + _p4 = _mm512_mul_ps(_p4, _scale); + _p5 = _mm512_mul_ps(_p5, _scale); + _p6 = _mm512_mul_ps(_p6, _scale); + _p7 = _mm512_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i 
_pp7 = float2int8_avx512(_p7); + + transpose8x8_epi16(_pp0, _pp1, _pp2, _pp3, _pp4, _pp5, _pp6, _pp7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + __m512i _t1 = combine4x4_epi32(_pp4, _pp5, _pp6, _pp7); + + _mm512_store_si512((__m512i*)pp, _t0); + _mm512_store_si512((__m512i*)(pp + 64), _t1); + + pp += 128; + p0 += B_hstep * 16; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + __m256 _p4 = _mm256_load_ps(p0 + 32); + __m256 _p5 = _mm256_load_ps(p0 + 40); + __m256 _p6 = _mm256_load_ps(p0 + 48); + __m256 _p7 = _mm256_load_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + _p4 = _mm256_mul_ps(_p4, _scale); + _p5 = _mm256_mul_ps(_p5, _scale); + _p6 = _mm256_mul_ps(_p6, _scale); + _p7 = _mm256_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); + + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t2 = _mm256_unpacklo_epi32(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi32(_t0, _t1); + _t0 = _mm256_unpacklo_epi64(_t2, _t3); + _t1 = _mm256_unpackhi_epi64(_t2, _t3); +#if !__AVXVNNIINT8__ + _t0 = _mm256_add_epi8(_t0, _v127); + _t1 = _mm256_add_epi8(_t1, _v127); +#endif // !__AVXVNNIINT8__ + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)(pp + 32), _t1); + + pp += 64; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + __m256 _p4 = _mm256_load_ps(p0 + 32); + __m256 _p5 = _mm256_load_ps(p0 + 40); + __m256 _p6 = _mm256_load_ps(p0 + 48); + __m256 _p7 = _mm256_load_ps(p0 + 56); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + _p4 = _mm256_mul_ps(_p4, _scale); + _p5 = _mm256_mul_ps(_p5, _scale); + _p6 = _mm256_mul_ps(_p6, _scale); + _p7 = _mm256_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + __m128i _pp2 = float2int8_avx(_p4, _p6); + __m128i _pp3 = float2int8_avx(_p5, _p7); + +#if __AVX2__ + __m256i _t0 = combine4x2_epi32(_pp0, _pp2); + __m256i _t1 = combine4x2_epi32(_pp1, _pp3); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi32(_t2, _t3); + _t1 = _mm256_unpackhi_epi32(_t2, _t3); + _t0 = _mm256_permute4x64_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm256_permute4x64_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); +#else + __m128i _tt0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _tt2 = _mm_unpacklo_epi16(_pp2, _pp3); + __m128i _tt3 = _mm_unpackhi_epi16(_pp2, _pp3); + _pp0 = _mm_unpacklo_epi32(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi32(_tt0, _tt1); + _pp2 = _mm_unpacklo_epi32(_tt2, _tt3); + _pp3 = _mm_unpackhi_epi32(_tt2, _tt3); + _tt0 = _mm_unpacklo_epi64(_pp0, _pp2); + 
_tt1 = _mm_unpackhi_epi64(_pp0, _pp2); + _tt2 = _mm_unpacklo_epi64(_pp1, _pp3); + _tt3 = _mm_unpackhi_epi64(_pp1, _pp3); + __m256i _t0 = combine4x2_epi32(_tt0, _tt1); + __m256i _t1 = combine4x2_epi32(_tt2, _tt3); +#endif + _mm256_store_si256((__m256i*)pp, _t0); + _mm256_store_si256((__m256i*)(pp + 32), _t1); + + pp += 64; + p0 += B_hstep * 8; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } +#endif // __AVX__ + if (elempack == 4) + { +#if __AVX__ + __m256 _scale = _mm256_set1_ps(scale); +#else + __m128 _scale = _mm_set1_ps(scale); +#endif + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _pp = _mm256_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm256_store_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVX__ + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + 8); + __m256 _p2 = _mm256_loadu_ps(p0 + 16); + __m256 _p3 = _mm256_loadu_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p1); + __m128i _pp1 = float2int8_avx(_p2, _p3); +#else + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + __m128 _p4 = _mm_load_ps(p0 + 16); + __m128 _p5 = _mm_load_ps(p0 + 20); + __m128 _p6 = _mm_load_ps(p0 + 24); + __m128 _p7 = _mm_load_ps(p0 + 28); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + _p4 = _mm_mul_ps(_p4, _scale); + _p5 = _mm_mul_ps(_p5, _scale); + _p6 = _mm_mul_ps(_p6, _scale); + _p7 = _mm_mul_ps(_p7, _scale); + + __m128i _pp0 = float2int8_sse(_p0, _p1, _p2, _p3); + __m128i _pp1 = float2int8_sse(_p4, _p5, _p6, _p7); +#endif + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + __m128i _t2 = _mm_unpacklo_epi16(_t0, _t1); + __m128i _t3 = _mm_unpackhi_epi16(_t0, _t1); + _t0 = _mm_unpacklo_epi16(_t2, _t3); + _t1 = _mm_unpackhi_epi16(_t2, _t3); + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } + if (elempack == 1) + { +#if __AVX__ + __m256 _scale = _mm256_set1_ps(scale); +#else + __m128 _scale = _mm_set1_ps(scale); +#endif + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep); + __m256 _p2 = _mm256_loadu_ps(p0 + B_hstep * 2); + __m256 _p3 = _mm256_loadu_ps(p0 + B_hstep * 3); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = 
_mm_unpackhi_epi8(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi16(_tt0, _tt1); + _pp1 = _mm_unpackhi_epi16(_tt0, _tt1); + + __m256i _pp = combine4x2_epi32(_pp0, _pp1); +#if !__AVXVNNIINT8__ + _pp = _mm256_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm256_storeu_si256((__m256i*)pp, _pp); + + pp += 32; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + B_hstep); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + __m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#else + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + 4); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep + 4); + + __m128 _t0 = _mm_unpacklo_ps(_p0, _p2); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p2); + __m128 _t2 = _mm_unpacklo_ps(_p1, _p3); + __m128 _t3 = _mm_unpackhi_ps(_p1, _p3); + + _t0 = _mm_mul_ps(_t0, _scale); + _t1 = _mm_mul_ps(_t1, _scale); + _t2 = _mm_mul_ps(_t2, _scale); + _t3 = _mm_mul_ps(_t3, _scale); + + __m128i _pp = float2int8_sse(_t0, _t1, _t2, _t3); +#endif + +#if __AVX512F__ + _mm_store_si128((__m128i*)pp, _pp); +#else + _mm_storeu_si128((__m128i*)pp, _pp); +#endif + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { +#if __AVX__ + __m256 _p = _mm256_loadu_ps(p0); + + _p = _mm256_mul_ps(_p, _scale); + + int64_t v = float2int8_avx(_p); +#else + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + 4); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + + int64_t v = float2int8_sse(_p0, _p1); +#endif + *(int64_t*)pp = v; + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = _mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); + _pp2 = _mm_add_epi8(_pp2, _v127); + _pp3 = _mm_add_epi8(_pp3, _v127); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + _mm_store_si128((__m128i*)(pp + 32), _pp2); + _mm_store_si128((__m128i*)(pp + 48), _pp3); + + pp += 64; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + __m512 _p2 = _mm512_load_ps(p0 + 32); + __m512 _p3 = _mm512_load_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + _p2 = _mm512_mul_ps(_p2, _scale); + _p3 = 
_mm512_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose8x4_epi16(_pp0, _pp1, _pp2, _pp3); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + _mm_store_si128((__m128i*)(pp + 32), _pp2); + _mm_store_si128((__m128i*)(pp + 48), _pp3); + + pp += 64; + p0 += B_hstep * 16; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _t0 = _mm_unpacklo_epi32(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi32(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi64(_t0, _t1); + _pp1 = _mm_unpackhi_epi64(_t0, _t1); +#if !__AVXVNNIINT8__ + _pp0 = _mm_add_epi8(_pp0, _v127); + _pp1 = _mm_add_epi8(_pp1, _v127); +#endif // !__AVXVNNIINT8__ + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + __m256 _p2 = _mm256_load_ps(p0 + 16); + __m256 _p3 = _mm256_load_ps(p0 + 24); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + _p2 = _mm256_mul_ps(_p2, _scale); + _p3 = _mm256_mul_ps(_p3, _scale); + + __m128i _pp0 = float2int8_avx(_p0, _p2); + __m128i _pp1 = float2int8_avx(_p1, _p3); + + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); + _pp0 = _mm_unpacklo_epi32(_t0, _t1); + _pp1 = _mm_unpackhi_epi32(_t0, _t1); + + _mm_store_si128((__m128i*)pp, _pp0); + _mm_store_si128((__m128i*)(pp + 16), _pp1); + + pp += 32; + p0 += B_hstep * 8; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scale = _mm_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); +#if !__AVXVNNIINT8__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + __m128 _p2 = _mm_load_ps(p0 + 8); + __m128 _p3 = _mm_load_ps(p0 + 12); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + _pp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pp, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _pp = _mm_shuffle_epi32(_pp, _MM_SHUFFLE(3, 1, 2, 0)); + + 
_mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + } + if (elempack == 1) + { + __m128 _scale = _mm_set1_ps(scale); + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + B_hstep * 3); + + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + _p2 = _mm_mul_ps(_p2, _scale); + _p3 = _mm_mul_ps(_p3, _scale); + + __m128i _pp = float2int8_sse(_p0, _p1, _p2, _p3); + + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); +#if !__AVXVNNIINT8__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + B_hstep); + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); + __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); + __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + __m128 _p = _mm_loadu_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scale); + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_load_ps(p0); + __m512 _p1 = _mm512_load_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scale); + _p1 = _mm512_mul_ps(_p1, _scale); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + +#if __AVX512VNNI__ + __m128i _t0 = _mm_unpacklo_epi32(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi32(_pp0, _pp1); + + _t0 = _mm_add_epi8(_t0, _v127); + _t1 = _mm_add_epi8(_t1, _v127); +#else // __AVX512VNNI__ + __m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1); +#endif // __AVX512VNNI__ + + _mm_store_si128((__m128i*)pp, _t0); + _mm_store_si128((__m128i*)(pp + 16), _t1); + + pp += 32; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scale); + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p0 = _mm256_load_ps(p0); + __m256 _p1 = _mm256_load_ps(p0 + 8); + + _p0 = _mm256_mul_ps(_p0, _scale); + _p1 = _mm256_mul_ps(_p1, _scale); + + __m128i _pp = float2int8_avx(_p0, _p1); + + _pp = _mm_shuffle_epi32(_pp, _MM_SHUFFLE(3, 1, 2, 0)); +#if __AVX512VNNI__ || __AVXVNNI__ +#if !__AVXVNNIINT8__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + _pp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pp, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scale = _mm_set1_ps(scale); + + int kk = 0; + for 
(; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_load_ps(p0); + __m128 _p1 = _mm_load_ps(p0 + 4); + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __SSE2__ + if (elempack == 1) + { + int kk = 0; +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep))); + __m128 _p2 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 2))); + __m128 _p3 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 3))); + __m128 _p01 = _mm_unpacklo_ps(_p0, _p1); + __m128 _p23 = _mm_unpacklo_ps(_p2, _p3); + _p0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23))); + _p1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p01), _mm_castps_pd(_p23))); + _p0 = _mm_mul_ps(_p0, _scale); + _p1 = _mm_mul_ps(_p1, _scale); +#if __AVX512VNNI__ || __AVXVNNI__ + int64_t v = float2int8_sse(_p0, _p1); + *(int64_t*)pp = v; +#if !__AVXVNNIINT8__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; +#endif // !__AVXVNNIINT8__ +#else // __AVX512VNNI__ || __AVXVNNI__ + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + int64_t v = float2int8_sse(_t0, _t1); + *(int64_t*)pp = v; +#endif // __AVX512VNNI__ || __AVXVNNI__ + pp += 8; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + __m128 _p0 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)p0)); + __m128 _p1 = _mm_castsi128_ps(_mm_loadl_epi64((const __m128i*)(p0 + B_hstep))); + __m128 _p = _mm_unpacklo_ps(_p0, _p1); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; + pp += 4; + p0 += B_hstep * 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __AVX512VNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512 _scale = _mm512_set1_ps(scale); + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p = _mm512_load_ps(p0); + + _p = _mm512_mul_ps(_p, _scale); + + __m128i _pp = float2int8_avx512(_p); + +#if __AVX512VNNI__ + _pp = _mm_add_epi8(_pp, _v127); +#endif // __AVX512VNNI__ + + _mm_store_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + __m256 _scale = _mm256_set1_ps(scale); + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + __m256 _p = _mm256_load_ps(p0); + 
_p = _mm256_mul_ps(_p, _scale); + int64_t v = float2int8_avx(_p); + *(int64_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 8; + p0 += B_hstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + __m128 _scale = _mm_set1_ps(scale); + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p = _mm_load_ps(p0); + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __SSE2__ + if (elempack == 1) + { + int kk = 0; +#if __SSE2__ + __m128 _scale = _mm_set1_ps(scale); + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep)); + + __m128 _p = _mm_i32gather_ps(p0, _vindex, sizeof(float)); +#else + __m128 _p = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]); +#endif + _p = _mm_mul_ps(_p, _scale); + int32_t v = float2int8_sse(_p); + *(int32_t*)pp = v; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + pp += 4; + p0 += B_hstep * 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + unpack_output_tile_int32_to_fp32_avx2(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? 
(int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack, output_transpose); + + const int* pp = topT; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m512 _descale = _mm512_load_ps((const float*)descales + i + ii); + + __m512 _c0 = _mm512_set1_ps(0.f); + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm512_loadu_ps(pC); + _c0 = _mm512_mul_ps(_c0, _mm512_set1_ps(beta)); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 64))); + __m512 _f5 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 112))); + __m512 _f8 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128))); + __m512 _f9 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 16))); + __m512 _fa = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 32))); + __m512 _fb = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 48))); + __m512 _fc = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 64))); + __m512 _fd = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 80))); + __m512 _fe = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 96))); + __m512 _ff = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 112))); + pp += 256; + + // from + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 01 12 23 30 45 56 67 74 89 9a ab b8 cd de ef fc + // 20 31 02 13 64 75 46 57 a8 b9 8a 9b ec fd ce df + // 21 32 03 10 65 76 47 54 a9 ba 8b 98 ed fe cf dc + // 08 19 2a 3b 4c 5d 6e 7f 80 91 a2 b3 c4 d5 e6 f7 + // 09 1a 2b 38 4d 5e 6f 7c 81 92 a3 b0 c5 d6 e7 f4 + // 28 39 0a 1b 6c 7d 4e 5f a0 b1 82 93 e4 f5 c6 d7 + // 29 3a 0b 18 6d 7e 4f 5c a1 b2 83 90 e5 f6 c7 d4 + // 40 51 62 73 04 15 26 37 c8 d9 ea fb 8c 9d ae bf + // 41 52 63 70 05 16 27 34 c9 da eb f8 8d 9e af bc + // 60 71 42 53 24 35 06 17 e8 f9 ca db ac bd 8e 9f + // 61 72 43 50 25 36 07 14 e9 fa cb d8 ad be 8f 9c + // 48 59 6a 7b 0c 1d 2e 3f c0 d1 e2 f3 84 95 a6 b7 + // 49 5a 6b 78 0d 1e 2f 3c c1 d2 e3 f0 85 96 a7 b4 + // 68 79 4a 5b 2c 3d 0e 1f e0 f1 c2 d3 a4 b5 86 97 + // 69 7a 4b 58 2d 3e 0f 1c e1 f2 c3 d0 a5 b6 87 94 + + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 
72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + // 08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + __m512 _tmp8 = _mm512_unpacklo_ps(_f8, _fb); + __m512 _tmp9 = _mm512_unpackhi_ps(_f8, _fb); + __m512 _tmpa = _mm512_unpacklo_ps(_fa, _f9); + __m512 _tmpb = _mm512_unpackhi_ps(_fa, _f9); + __m512 _tmpc = _mm512_unpacklo_ps(_fc, _ff); + __m512 _tmpd = _mm512_unpackhi_ps(_fc, _ff); + __m512 _tmpe = _mm512_unpacklo_ps(_fe, _fd); + __m512 _tmpf = _mm512_unpackhi_ps(_fe, _fd); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f8 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _f9 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _fa = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fb = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fc = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fd = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fe = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + _ff = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = 
_mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f8, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _fa, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f8, _f0, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f9, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_fa, _f2, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_fb, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp8 = _mm512_shuffle_f32x4(_f4, _fc, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp9 = _mm512_shuffle_f32x4(_f5, _fd, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpa = _mm512_shuffle_f32x4(_f6, _fe, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpb = _mm512_shuffle_f32x4(_f7, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpc = _mm512_shuffle_f32x4(_fc, _f4, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpd = _mm512_shuffle_f32x4(_fd, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpe = _mm512_shuffle_f32x4(_fe, _f6, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpf = _mm512_shuffle_f32x4(_ff, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _f8 = _mm512_shuffle_f32x4(_tmp8, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f9 = _mm512_shuffle_f32x4(_tmp9, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _fa = _mm512_shuffle_f32x4(_tmpa, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _fb = _mm512_shuffle_f32x4(_tmpb, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _fc = _mm512_shuffle_f32x4(_tmpc, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _fd = _mm512_shuffle_f32x4(_tmpd, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _fe = _mm512_shuffle_f32x4(_tmpe, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _ff = _mm512_shuffle_f32x4(_tmpf, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + _f4 = _mm512_mul_ps(_f4, _descale); + _f5 = _mm512_mul_ps(_f5, _descale); + _f6 = _mm512_mul_ps(_f6, _descale); + _f7 = _mm512_mul_ps(_f7, _descale); + _f8 = _mm512_mul_ps(_f8, _descale); + _f9 = _mm512_mul_ps(_f9, _descale); + _fa = _mm512_mul_ps(_fa, _descale); + _fb = _mm512_mul_ps(_fb, _descale); + _fc = _mm512_mul_ps(_fc, _descale); + _fd = _mm512_mul_ps(_fd, _descale); + _fe = _mm512_mul_ps(_fe, _descale); + _ff = _mm512_mul_ps(_ff, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = 
_mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + __m512 _c4; + __m512 _c5; + __m512 _c6; + __m512 _c7; + __m512 _c8; + __m512 _c9; + __m512 _ca; + __m512 _cb; + __m512 _cc; + __m512 _cd; + __m512 _ce; + __m512 _cf; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + 64); + _c5 = _mm512_loadu_ps(pC + 80); + _c6 = _mm512_loadu_ps(pC + 96); + _c7 = _mm512_loadu_ps(pC + 112); + _c8 = _mm512_loadu_ps(pC + 128); + _c9 = _mm512_loadu_ps(pC + 128 + 16); + _ca = _mm512_loadu_ps(pC + 128 + 32); + _cb = _mm512_loadu_ps(pC + 128 + 48); + _cc = _mm512_loadu_ps(pC + 128 + 64); + _cd = _mm512_loadu_ps(pC + 128 + 80); + _ce = _mm512_loadu_ps(pC + 128 + 96); + _cf = _mm512_loadu_ps(pC + 128 + 112); + pC += 256; + } + else if (c_elempack == 8) + { + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + 64); + __m512 _tmp5 = _mm512_loadu_ps(pC + 80); + __m512 _tmp6 = _mm512_loadu_ps(pC + 96); + __m512 _tmp7 = _mm512_loadu_ps(pC + 112); + __m512 _tmp8 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _tmp9 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + __m512 _tmpa = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + __m512 _tmpb = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + __m512 _tmpc = _mm512_loadu_ps(pC + c_hstep * 8 + 64); + __m512 _tmpd = _mm512_loadu_ps(pC + c_hstep * 8 + 80); + __m512 _tmpe = _mm512_loadu_ps(pC + c_hstep * 8 + 96); + __m512 _tmpf = _mm512_loadu_ps(pC + c_hstep * 8 + 112); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + _c4 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(3, 2, 3, 2)); + _c6 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(3, 2, 3, 2)); + _c8 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(1, 0, 1, 0)); + _c9 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(3, 2, 3, 2)); + _ca = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(1, 0, 1, 0)); + _cb = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(3, 2, 3, 2)); + _cc = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + _cd = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + _ce = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + 
_cf = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(3, 2, 3, 2)); + + pC += 128; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c6 = _mm512_loadu_ps(pC + c_hstep * 4 + 32); + _c7 = _mm512_loadu_ps(pC + c_hstep * 4 + 48); + _c8 = _mm512_loadu_ps(pC + c_hstep * 8); + _c9 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _ca = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + _cb = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + _cc = _mm512_loadu_ps(pC + c_hstep * 12); + _cd = _mm512_loadu_ps(pC + c_hstep * 12 + 16); + _ce = _mm512_loadu_ps(pC + c_hstep * 12 + 32); + _cf = _mm512_loadu_ps(pC + c_hstep * 12 + 48); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c8, _cc, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c8, _cc, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c1, _c5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c9, _cd, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c1, _c5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c9, _cd, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp8 = _mm512_shuffle_f32x4(_c2, _c6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_ca, _ce, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpa = _mm512_shuffle_f32x4(_c2, _c6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpb = _mm512_shuffle_f32x4(_ca, _ce, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpc = _mm512_shuffle_f32x4(_c3, _c7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpd = _mm512_shuffle_f32x4(_cb, _cf, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmpe = _mm512_shuffle_f32x4(_c3, _c7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmpf = _mm512_shuffle_f32x4(_cb, _cf, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _c4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _c8 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _c9 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _ca = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _cb = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _cc = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _cd = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _ce = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + _cf = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + pC += 64; + } + else // if (c_elempack == 1) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + c_hstep); + _c2 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3 = _mm512_loadu_ps(pC + c_hstep * 3); + _c4 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5 = _mm512_loadu_ps(pC + c_hstep * 5); + _c6 = _mm512_loadu_ps(pC + c_hstep * 6); + _c7 = _mm512_loadu_ps(pC + c_hstep * 7); + _c8 = _mm512_loadu_ps(pC + c_hstep * 8); + _c9 = _mm512_loadu_ps(pC 
+ c_hstep * 9); + _ca = _mm512_loadu_ps(pC + c_hstep * 10); + _cb = _mm512_loadu_ps(pC + c_hstep * 11); + _cc = _mm512_loadu_ps(pC + c_hstep * 12); + _cd = _mm512_loadu_ps(pC + c_hstep * 13); + _ce = _mm512_loadu_ps(pC + c_hstep * 14); + _cf = _mm512_loadu_ps(pC + c_hstep * 15); + transpose16x16_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + pC += 16; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + _f4 = _mm512_add_ps(_f4, _c4); + _f5 = _mm512_add_ps(_f5, _c5); + _f6 = _mm512_add_ps(_f6, _c6); + _f7 = _mm512_add_ps(_f7, _c7); + _f8 = _mm512_add_ps(_f8, _c8); + _f9 = _mm512_add_ps(_f9, _c9); + _fa = _mm512_add_ps(_fa, _ca); + _fb = _mm512_add_ps(_fb, _cb); + _fc = _mm512_add_ps(_fc, _cc); + _fd = _mm512_add_ps(_fd, _cd); + _fe = _mm512_add_ps(_fe, _ce); + _ff = _mm512_add_ps(_ff, _cf); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7, _beta, _f7); + _f8 = _mm512_fmadd_ps(_c8, _beta, _f8); + _f9 = _mm512_fmadd_ps(_c9, _beta, _f9); + _fa = _mm512_fmadd_ps(_ca, _beta, _fa); + _fb = _mm512_fmadd_ps(_cb, _beta, _fb); + _fc = _mm512_fmadd_ps(_cc, _beta, _fc); + _fd = _mm512_fmadd_ps(_cd, _beta, _fd); + _fe = _mm512_fmadd_ps(_ce, _beta, _fe); + _ff = _mm512_fmadd_ps(_cf, _beta, _ff); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + + _c0 = _mm512_set1_ps(pC[4] * beta); + _c1 = _mm512_set1_ps(pC[5] * beta); + _c2 = _mm512_set1_ps(pC[6] * beta); + _c3 = _mm512_set1_ps(pC[7] * beta); + + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c1); + _f6 = _mm512_add_ps(_f6, _c2); + _f7 = _mm512_add_ps(_f7, _c3); + + _c0 = _mm512_set1_ps(pC[8] * beta); + _c1 = _mm512_set1_ps(pC[9] * beta); + _c2 = _mm512_set1_ps(pC[10] * beta); + _c3 = _mm512_set1_ps(pC[11] * beta); + + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c1); + _fa = _mm512_add_ps(_fa, _c2); + _fb = _mm512_add_ps(_fb, _c3); + + _c0 = _mm512_set1_ps(pC[12] * beta); + _c1 = _mm512_set1_ps(pC[13] * beta); + _c2 = _mm512_set1_ps(pC[14] * beta); + _c3 = _mm512_set1_ps(pC[15] * beta); + + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c1); + _fe = _mm512_add_ps(_fe, _c2); + _ff = _mm512_add_ps(_ff, _c3); + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + _f8 = _mm512_mul_ps(_f8, _alpha); + _f9 = _mm512_mul_ps(_f9, _alpha); + _fa = _mm512_mul_ps(_fa, _alpha); + _fb = _mm512_mul_ps(_fb, _alpha); + _fc = _mm512_mul_ps(_fc, _alpha); + _fd = _mm512_mul_ps(_fd, _alpha); + _fe = _mm512_mul_ps(_fe, _alpha); + _ff = _mm512_mul_ps(_ff, _alpha); 
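+ // each of _f0.._ff now holds one output column of this 16x16 tile, equal to
+ // alpha * (sum_int32 * descale + beta * c); a scalar sketch of the epilogue applied
+ // above (for reading convenience, not the actual code path):
+ //   f = (float)sum * descale_of_row;
+ //   if (pC) f += beta * c;   // c picked per broadcast_type_C
+ //   f *= alpha;
+ // the stores below lay the tile out according to output_transpose and out_elempack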
+ } + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + 16 * 4, _f4); + _mm512_storeu_ps(p0 + 16 * 5, _f5); + _mm512_storeu_ps(p0 + 16 * 6, _f6); + _mm512_storeu_ps(p0 + 16 * 7, _f7); + _mm512_storeu_ps(p0 + 16 * 8, _f8); + _mm512_storeu_ps(p0 + 16 * 9, _f9); + _mm512_storeu_ps(p0 + 16 * 10, _fa); + _mm512_storeu_ps(p0 + 16 * 11, _fb); + _mm512_storeu_ps(p0 + 16 * 12, _fc); + _mm512_storeu_ps(p0 + 16 * 13, _fd); + _mm512_storeu_ps(p0 + 16 * 14, _fe); + _mm512_storeu_ps(p0 + 16 * 15, _ff); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + transpose16x8_ps(_f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + 16 * 4, _f4); + _mm512_storeu_ps(p0 + 16 * 5, _f5); + _mm512_storeu_ps(p0 + 16 * 6, _f6); + _mm512_storeu_ps(p0 + 16 * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 2, _fa); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 3, _fb); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 4, _fc); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 5, _fd); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 6, _fe); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 7, _ff); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + transpose16x4_ps(_f8, _f9, _fa, _fb); + transpose16x4_ps(_fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0 + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0 + out_hstep * 12 + 32, _fe); + _mm512_storeu_ps(p0 + out_hstep * 12 + 48, _ff); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 9, _f9); + _mm512_storeu_ps(p0 + out_hstep * 10, _fa); + _mm512_storeu_ps(p0 + out_hstep * 11, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 13, _fd); + _mm512_storeu_ps(p0 + out_hstep * 14, _fe); + _mm512_storeu_ps(p0 + out_hstep * 15, _ff); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + _mm512_store_ps(p0 + 32, _f2); + _mm512_store_ps(p0 + 48, _f3); + _mm512_store_ps(p0 + 64, _f4); + _mm512_store_ps(p0 + 80, _f5); + 
_mm512_store_ps(p0 + 96, _f6); + _mm512_store_ps(p0 + 112, _f7); + _mm512_store_ps(p0 + 128, _f8); + _mm512_store_ps(p0 + 128 + 16, _f9); + _mm512_store_ps(p0 + 128 + 32, _fa); + _mm512_store_ps(p0 + 128 + 48, _fb); + _mm512_store_ps(p0 + 128 + 64, _fc); + _mm512_store_ps(p0 + 128 + 80, _fd); + _mm512_store_ps(p0 + 128 + 96, _fe); + _mm512_store_ps(p0 + 128 + 112, _ff); + p0 += 256; + } + if (out_elempack == 8) + { + _mm256_store_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0 + 16, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_store_ps(p0 + 24, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_store_ps(p0 + 32, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_store_ps(p0 + 40, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_store_ps(p0 + 48, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_store_ps(p0 + 56, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_store_ps(p0 + 64, _mm512_extractf32x8_ps(_f8, 0)); + _mm256_store_ps(p0 + 64 + 8, _mm512_extractf32x8_ps(_f9, 0)); + _mm256_store_ps(p0 + 64 + 16, _mm512_extractf32x8_ps(_fa, 0)); + _mm256_store_ps(p0 + 64 + 24, _mm512_extractf32x8_ps(_fb, 0)); + _mm256_store_ps(p0 + 64 + 32, _mm512_extractf32x8_ps(_fc, 0)); + _mm256_store_ps(p0 + 64 + 40, _mm512_extractf32x8_ps(_fd, 0)); + _mm256_store_ps(p0 + 64 + 48, _mm512_extractf32x8_ps(_fe, 0)); + _mm256_store_ps(p0 + 64 + 56, _mm512_extractf32x8_ps(_ff, 0)); + _mm256_store_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 24, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 32, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 40, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 48, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 56, _mm512_extractf32x8_ps(_f7, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64, _mm512_extractf32x8_ps(_f8, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 8, _mm512_extractf32x8_ps(_f9, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 16, _mm512_extractf32x8_ps(_fa, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 24, _mm512_extractf32x8_ps(_fb, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 32, _mm512_extractf32x8_ps(_fc, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 40, _mm512_extractf32x8_ps(_fd, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 48, _mm512_extractf32x8_ps(_fe, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 64 + 56, _mm512_extractf32x8_ps(_ff, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_fc, _fd, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + + __m512 _tmp8 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpa = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpb = 
_mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpc = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpd = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpe = _mm512_shuffle_f32x4(_fc, _fd, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpf = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + + _f8 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f9 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _fa = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _fb = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _fc = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _fd = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _fe = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 1, 3, 1)); + _ff = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0 + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0 + out_hstep * 12 + 32, _fe); + _mm512_storeu_ps(p0 + out_hstep * 12 + 48, _ff); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 9, _f9); + _mm512_storeu_ps(p0 + out_hstep * 10, _fa); + _mm512_storeu_ps(p0 + out_hstep * 11, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 13, _fd); + _mm512_storeu_ps(p0 + out_hstep * 14, _fe); + _mm512_storeu_ps(p0 + out_hstep * 15, _ff); + p0 += 16; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 64))); + __m512 _f5 = 
_mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 112))); + pp += 128; + + // from + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 01 12 23 30 45 56 67 74 81 92 a3 b0 c5 d6 e7 f4 + // 20 31 02 13 64 75 46 57 a0 b1 82 93 e4 f5 c6 d7 + // 21 32 03 10 65 76 47 54 a1 b2 83 90 e5 f6 c7 d4 + // 04 15 26 37 40 51 62 73 84 95 a6 b7 c0 d1 e2 f3 + // 05 16 27 34 41 52 63 70 85 96 a7 b4 c1 d2 e3 f0 + // 24 35 06 17 60 71 42 53 a4 b5 86 97 e0 f1 c2 d3 + // 25 36 07 14 61 72 43 50 a5 b6 87 94 e1 f2 c3 d0 + // + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp5 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = 
_mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + _f4 = _mm512_mul_ps(_f4, _descale); + _f5 = _mm512_mul_ps(_f5, _descale); + _f6 = _mm512_mul_ps(_f6, _descale); + _f7 = _mm512_mul_ps(_f7, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + __m512 _c4; + __m512 _c5; + __m512 _c6; + __m512 _c7; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + 64); + _c5 = _mm512_loadu_ps(pC + 80); + _c6 = _mm512_loadu_ps(pC + 96); + _c7 = _mm512_loadu_ps(pC + 112); + pC += 128; + } + else if (c_elempack == 8) + { + __m512 _tmp0 = _mm512_loadu_ps(pC); + __m512 _tmp1 = _mm512_loadu_ps(pC + 16); + __m512 _tmp2 = _mm512_loadu_ps(pC + 32); + __m512 _tmp3 = _mm512_loadu_ps(pC + 48); + __m512 _tmp4 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _tmp5 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + __m512 _tmp6 = _mm512_loadu_ps(pC + c_hstep * 8 + 32); + __m512 _tmp7 = _mm512_loadu_ps(pC + c_hstep * 8 + 48); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + _c4 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 2, 3, 2)); + _c6 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 2, 3, 2)); + + pC += 64; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + c_hstep * 4); + _c3 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c4 = _mm512_loadu_ps(pC + c_hstep * 8); + _c5 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _c6 = _mm512_loadu_ps(pC + c_hstep * 12); + _c7 = _mm512_loadu_ps(pC + c_hstep * 12 + 16); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c4, _c6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c4, _c6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c1, _c3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c5, _c7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c1, _c3, _MM_SHUFFLE(3, 2, 3, 2)); + 
__m512 _tmp7 = _mm512_shuffle_f32x4(_c5, _c7, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _c4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c5 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c6 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + pC += 32; + } + else // if (c_elempack == 1) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep); + __m256 _cc2 = _mm256_loadu_ps(pC + c_hstep * 2); + __m256 _cc3 = _mm256_loadu_ps(pC + c_hstep * 3); + __m256 _cc4 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _cc5 = _mm256_loadu_ps(pC + c_hstep * 5); + __m256 _cc6 = _mm256_loadu_ps(pC + c_hstep * 6); + __m256 _cc7 = _mm256_loadu_ps(pC + c_hstep * 7); + __m256 _cc8 = _mm256_loadu_ps(pC + c_hstep * 8); + __m256 _cc9 = _mm256_loadu_ps(pC + c_hstep * 9); + __m256 _cca = _mm256_loadu_ps(pC + c_hstep * 10); + __m256 _ccb = _mm256_loadu_ps(pC + c_hstep * 11); + __m256 _ccc = _mm256_loadu_ps(pC + c_hstep * 12); + __m256 _ccd = _mm256_loadu_ps(pC + c_hstep * 13); + __m256 _cce = _mm256_loadu_ps(pC + c_hstep * 14); + __m256 _ccf = _mm256_loadu_ps(pC + c_hstep * 15); + transpose8x8_ps(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); + transpose8x8_ps(_cc8, _cc9, _cca, _ccb, _ccc, _ccd, _cce, _ccf); + _c0 = combine8x2_ps(_cc0, _cc8); + _c1 = combine8x2_ps(_cc1, _cc9); + _c2 = combine8x2_ps(_cc2, _cca); + _c3 = combine8x2_ps(_cc3, _ccb); + _c4 = combine8x2_ps(_cc4, _ccc); + _c5 = combine8x2_ps(_cc5, _ccd); + _c6 = combine8x2_ps(_cc6, _cce); + _c7 = combine8x2_ps(_cc7, _ccf); + pC += 8; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + _f4 = _mm512_add_ps(_f4, _c4); + _f5 = _mm512_add_ps(_f5, _c5); + _f6 = _mm512_add_ps(_f6, _c6); + _f7 = _mm512_add_ps(_f7, _c7); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + + _c0 = _mm512_set1_ps(pC[4] * beta); + _c1 = _mm512_set1_ps(pC[5] * beta); + _c2 = _mm512_set1_ps(pC[6] * beta); + _c3 = _mm512_set1_ps(pC[7] * beta); + + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c1); + _f6 = _mm512_add_ps(_f6, _c2); + _f7 = _mm512_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, 
_alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + 64, _f4); + _mm512_storeu_ps(p0 + 80, _f5); + _mm512_storeu_ps(p0 + 96, _f6); + _mm512_storeu_ps(p0 + 112, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + _mm512_store_ps(p0 + 32, _f2); + _mm512_store_ps(p0 + 48, _f3); + _mm512_store_ps(p0 + 64, _f4); + _mm512_store_ps(p0 + 80, _f5); + _mm512_store_ps(p0 + 96, _f6); + _mm512_store_ps(p0 + 112, _f7); + p0 += 128; + } + if (out_elempack == 8) + { + _mm256_store_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0 + 16, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_store_ps(p0 + 24, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_store_ps(p0 + 32, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_store_ps(p0 + 40, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_store_ps(p0 + 48, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_store_ps(p0 + 56, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_store_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 24, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 32, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 40, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 48, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 56, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 64; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 
= _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + out_hstep * 4, _f2); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f3); + _mm512_storeu_ps(p0 + out_hstep * 8, _f4); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 12, _f6); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _f7); + p0 += 32; + } + if (out_elempack == 1) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 8; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + pp += 64; + + // from + // 00 11 22 33 40 51 62 73 80 91 a2 b3 c0 d1 e2 f3 + // 01 12 23 30 41 52 63 70 81 92 a3 b0 c1 d2 e3 f0 + // 20 31 02 13 60 71 42 53 a0 b1 82 93 e0 f1 c2 d3 + // 21 32 03 10 61 72 43 50 a1 b2 83 90 e1 f2 c3 d0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 
1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + pC += 64; + } + else if (c_elempack == 8) + { + __m512 _cc0 = _mm512_loadu_ps(pC); + __m512 _cc1 = _mm512_loadu_ps(pC + 16); + __m512 _cc2 = _mm512_loadu_ps(pC + c_hstep * 8); + __m512 _cc3 = _mm512_loadu_ps(pC + c_hstep * 8 + 16); + _c0 = _mm512_shuffle_f32x4(_cc0, _cc2, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_cc0, _cc2, _MM_SHUFFLE(3, 2, 3, 2)); + _c2 = _mm512_shuffle_f32x4(_cc1, _cc3, _MM_SHUFFLE(1, 0, 1, 0)); + _c3 = _mm512_shuffle_f32x4(_cc1, _cc3, _MM_SHUFFLE(3, 2, 3, 2)); + pC += 32; + } + else if (c_elempack == 4) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + c_hstep * 4); + _c2 = _mm512_loadu_ps(pC + c_hstep * 8); + _c3 = _mm512_loadu_ps(pC + c_hstep * 12); + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0, _c1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c2, _c3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0, _c1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c2, _c3, _MM_SHUFFLE(3, 2, 3, 2)); + _c0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + pC += 16; + } + else // if (c_elempack == 1) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 2); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 3); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 5); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 6); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 7); + __m128 _cc8 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc9 = _mm_loadu_ps(pC + c_hstep * 9); + __m128 _cca = _mm_loadu_ps(pC + c_hstep * 10); + __m128 _ccb = _mm_loadu_ps(pC + c_hstep * 11); + __m128 _ccc = _mm_loadu_ps(pC + c_hstep * 12); + __m128 _ccd = _mm_loadu_ps(pC + c_hstep * 13); + __m128 _cce = _mm_loadu_ps(pC + c_hstep * 14); + __m128 _ccf = _mm_loadu_ps(pC + c_hstep * 15); + _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3); + _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7); + _MM_TRANSPOSE4_PS(_cc8, _cc9, _cca, _ccb); + _MM_TRANSPOSE4_PS(_ccc, _ccd, _cce, _ccf); + + _c0 = combine4x4_ps(_cc0, _cc4, _cc8, _ccc); + _c1 = combine4x4_ps(_cc1, _cc5, _cc9, _ccd); + _c2 = combine4x4_ps(_cc2, _cc6, _cca, _cce); + _c3 = combine4x4_ps(_cc3, _cc7, _ccb, _ccf); + + pC += 4; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, 
_beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + const int jj_m16 = jj % 16; + float* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_storeu_ps(p1, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p1 + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p1 + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p1 + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p1 + 64, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p1 + 80, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p1 + 96, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p1 + 112, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p1 + 128, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p1 + 144, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p1 + 160, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p1 + 176, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p1 + 192, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p1 + 208, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p1 + 224, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p1 + 240, _mm512_extractf32x4_ps(_f3, 3)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_storeu_ps(p1, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p1 + 8, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p1 + 16, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p1 + 24, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p1 + 32, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p1 + 40, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p1 + 48, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p1 + 56, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p1 + 64, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p1 + 72, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p1 + 80, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p1 + 88, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p1 + 96, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p1 + 104, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p1 + 112, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p1 + 120, _mm512_extractf32x4_ps(_f3, 3)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 16) + { + 
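+ // pack-16 output without transpose: each of the 4 columns is 16 consecutive floats,
+ // so this 16x4 block is written as 64 contiguous floats and p0 advances by 64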
_mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + _mm512_store_ps(p0 + 32, _f2); + _mm512_store_ps(p0 + 48, _f3); + p0 += 64; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(p0, _tmp0); + _mm512_storeu_ps(p0 + 16, _tmp1); + _mm512_storeu_ps(p0 + out_hstep * 8, _tmp2); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _tmp3); + p0 += 32; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep * 4, _f1); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 12, _f3); + p0 += 16; + } + if (out_elempack == 1) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + pp += 32; + + // from + // 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1 + // 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + { + __m512 _tmp0 = _mm512_permute_ps(_f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512 _tmp1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm512_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm512_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if 
(broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + pC += 32; + } + else if (c_elempack == 8) + { + __m512 _cc0 = _mm512_loadu_ps(pC); + __m512 _cc1 = _mm512_loadu_ps(pC + c_hstep * 8); + _c0 = _mm512_shuffle_f32x4(_cc0, _cc1, _MM_SHUFFLE(1, 0, 1, 0)); + _c1 = _mm512_shuffle_f32x4(_cc0, _cc1, _MM_SHUFFLE(3, 2, 3, 2)); + pC += 16; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + 4); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 4 + 4); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 8 + 4); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 12); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 12 + 4); + _c0 = combine4x4_ps(_cc0, _cc2, _cc4, _cc6); + _c1 = combine4x4_ps(_cc1, _cc3, _cc5, _cc7); + pC += 8; + } + else // if (c_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(c_hstep)); + _c0 = _mm512_i32gather_ps(_vindex, pC, sizeof(float)); + _c1 = _mm512_i32gather_ps(_vindex, pC + 1, sizeof(float)); + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + p0 += 32; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + _mm512_storeu_ps(p0, _tmp0); + _mm512_storeu_ps(p0 + out_hstep * 8, _tmp1); + p0 += 16; + } + if (out_elempack == 4) + { + _mm_store_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0 + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0 + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_f1, 3)); + p0 += 8; + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + pp += 16; + + _f0 = _mm512_mul_ps(_f0, _descale); + + if (pC) + { + if (broadcast_type_C == 
0) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + pC += 16; + } + else if (c_elempack == 8) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 8); + _c0 = combine8x2_ps(_cc0, _cc1); + pC += 8; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 8); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 12); + _c0 = combine4x4_ps(_cc0, _cc1, _cc2, _cc3); + pC += 4; + } + else // if (c_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(c_hstep)); + _c0 = _mm512_i32gather_ps(_vindex, pC, sizeof(float)); + pC += 1; + } + _f0 = _mm512_fmadd_ps(_c0, _mm512_set1_ps(beta), _f0); + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + pC += 1; + } + } + + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + + if (output_transpose) + { + _mm512_storeu_ps(p0, _f0); + p0 += out_hstep; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + p0 += 16; + } + if (out_elempack == 8) + { + _mm256_store_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + p0 += 4; + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + p0++; + } + } + } + } +#endif // __AVX512F__ +#if !__AVX2__ + const int* pp1 = pp + max_jj * 4; +#endif + for (; ii + 7 < max_ii; ii += 8) + { + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m256 _descale = _mm256_load_ps((const float*)descales + i + ii); +#if __AVX512F__ + __m512 _descale_avx512 = _mm512_broadcast_f32x8(_descale); +#endif + + __m256 _c0 = _mm256_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm256_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(pC[0] * beta); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm256_loadu_ps(pC); + _c0 = _mm256_mul_ps(_c0, _mm256_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x8(_c0); +#endif + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const 
__m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 64))); + __m512 _f5 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 112))); + pp += 128; + + // from + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 6e 7f + // 01 12 23 30 45 56 67 74 09 1a 2b 38 4d 5e 6f 7c + // 20 31 02 13 64 75 46 57 28 39 0a 1b 6c 7d 4e 5f + // 21 32 03 10 65 76 47 54 29 3a 0b 18 6d 7e 4f 5c + // 04 15 26 37 40 51 62 73 0c 1d 2e 3f 48 59 6a 7b + // 05 16 27 34 41 52 63 70 0d 1e 2f 3c 49 5a 6b 78 + // 24 35 06 17 60 71 42 53 2c 3d 0e 1f 68 79 4a 5b + // 25 36 07 14 61 72 43 50 2d 3e 0f 1c 69 7a 4b 58 + + // to + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 04 14 24 34 40 50 60 70 0c 1c 2c 3c 48 58 68 78 + // 05 15 25 35 41 51 61 71 0d 1d 2d 3d 49 59 69 79 + // 06 16 26 36 42 52 62 72 0e 1e 2e 3e 4a 5a 6a 7a + // 07 17 27 37 43 53 63 73 0f 1f 2f 3f 4b 5b 6b 7b + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp2 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp4 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = 
_mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale_avx512); + _f1 = _mm512_mul_ps(_f1, _descale_avx512); + _f2 = _mm512_mul_ps(_f2, _descale_avx512); + _f3 = _mm512_mul_ps(_f3, _descale_avx512); + _f4 = _mm512_mul_ps(_f4, _descale_avx512); + _f5 = _mm512_mul_ps(_f5, _descale_avx512); + _f6 = _mm512_mul_ps(_f6, _descale_avx512); + _f7 = _mm512_mul_ps(_f7, _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 3) + { + __m512 _c1_avx512; + __m512 _c2_avx512; + __m512 _c3_avx512; + __m512 _c4_avx512; + __m512 _c5_avx512; + __m512 _c6_avx512; + __m512 _c7_avx512; + if (c_elempack == 8) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + _c4_avx512 = _mm512_loadu_ps(pC + 64); + _c5_avx512 = _mm512_loadu_ps(pC + 80); + _c6_avx512 = _mm512_loadu_ps(pC + 96); + _c7_avx512 = _mm512_loadu_ps(pC + 112); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c4_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c0_avx512, _c4_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c1_avx512, _c5_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c1_avx512, _c5_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c2_avx512, _c6_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c2_avx512, _c6_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c3_avx512, _c7_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c3_avx512, _c7_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + + _c0_avx512 = _tmp0; + _c1_avx512 = _tmp1; + _c2_avx512 = _tmp2; + _c3_avx512 = _tmp3; + _c4_avx512 = _tmp4; + _c5_avx512 = _tmp5; + _c6_avx512 = _tmp6; + _c7_avx512 = _tmp7; + + pC += 128; + } + else if (c_elempack == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + _c4_avx512 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 16); + _c6_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 32); + _c7_avx512 = _mm512_loadu_ps(pC + c_hstep * 4 + 48); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c2_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = 
_mm512_shuffle_f32x4(_c0_avx512, _c2_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c1_avx512, _c3_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c1_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_c4_avx512, _c6_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_c4_avx512, _c6_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_c5_avx512, _c7_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_c5_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 1, 3, 1)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _c4_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _c7_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_c0_avx512, _c0_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_c1_avx512, _c1_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_c2_avx512, _c2_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_c3_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c4_avx512 = _mm512_shuffle_f32x4(_c4_avx512, _c4_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_c5_avx512, _c5_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_c6_avx512, _c6_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + _c7_avx512 = _mm512_shuffle_f32x4(_c7_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 2, 0)); + + pC += 64; + } + else // if (c_elempack == 1) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + _c2_avx512 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3_avx512 = _mm512_loadu_ps(pC + c_hstep * 3); + _c4_avx512 = _mm512_loadu_ps(pC + c_hstep * 4); + _c5_avx512 = _mm512_loadu_ps(pC + c_hstep * 5); + _c6_avx512 = _mm512_loadu_ps(pC + c_hstep * 6); + _c7_avx512 = _mm512_loadu_ps(pC + c_hstep * 7); + + __m512 _tmp0 = _mm512_unpacklo_ps(_c0_avx512, _c1_avx512); + __m512 _tmp1 = _mm512_unpacklo_ps(_c2_avx512, _c3_avx512); + __m512 _tmp2 = _mm512_unpacklo_ps(_c4_avx512, _c5_avx512); + __m512 _tmp3 = _mm512_unpacklo_ps(_c6_avx512, _c7_avx512); + __m512 _tmp4 = _mm512_unpackhi_ps(_c0_avx512, _c1_avx512); + __m512 _tmp5 = _mm512_unpackhi_ps(_c2_avx512, _c3_avx512); + __m512 _tmp6 = _mm512_unpackhi_ps(_c4_avx512, _c5_avx512); + __m512 _tmp7 = _mm512_unpackhi_ps(_c6_avx512, _c7_avx512); + + _c0_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c1_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c2_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c3_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c4_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _c5_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + _c6_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + 
_c7_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_c4_avx512, _c5_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_c6_avx512, _c7_avx512, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_c4_avx512, _c5_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_c6_avx512, _c7_avx512, _MM_SHUFFLE(3, 1, 3, 1)); + + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _c4_avx512 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _c5_avx512 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _c6_avx512 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _c7_avx512 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + pC += 16; + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + _f4 = _mm512_add_ps(_f4, _c4_avx512); + _f5 = _mm512_add_ps(_f5, _c5_avx512); + _f6 = _mm512_add_ps(_f6, _c6_avx512); + _f7 = _mm512_add_ps(_f7, _c7_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2_avx512, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3_avx512, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4_avx512, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5_avx512, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6_avx512, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7_avx512, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + __m512 _cc = _mm512_loadu_ps(pC); + _cc = _mm512_mul_ps(_cc, _mm512_set1_ps(beta)); + __m512 _cc0 = _mm512_permute_ps(_cc, _MM_SHUFFLE(0, 0, 0, 0)); + __m512 _cc1 = _mm512_permute_ps(_cc, _MM_SHUFFLE(1, 1, 1, 1)); + __m512 _cc2 = _mm512_permute_ps(_cc, _MM_SHUFFLE(2, 2, 2, 2)); + __m512 _cc3 = _mm512_permute_ps(_cc, _MM_SHUFFLE(3, 3, 3, 3)); + + _c0_avx512 = _mm512_shuffle_f32x4(_cc0, _cc0, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c1_avx512 = _mm512_shuffle_f32x4(_cc1, _cc1, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c2_avx512 = _mm512_shuffle_f32x4(_cc2, _cc2, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c3_avx512 = _mm512_shuffle_f32x4(_cc3, _cc3, _MM_SHUFFLE(2, 2, 0, 0)); + __m512 _c4_avx512 = _mm512_shuffle_f32x4(_cc0, _cc0, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c5_avx512 = _mm512_shuffle_f32x4(_cc1, _cc1, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c6_avx512 = _mm512_shuffle_f32x4(_cc2, _cc2, _MM_SHUFFLE(3, 3, 1, 1)); + __m512 _c7_avx512 = _mm512_shuffle_f32x4(_cc3, _cc3, _MM_SHUFFLE(3, 3, 1, 1)); + + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + _f4 = _mm512_add_ps(_f4, _c4_avx512); + _f5 = _mm512_add_ps(_f5, _c5_avx512); + _f6 = _mm512_add_ps(_f6, _c6_avx512); + _f7 = _mm512_add_ps(_f7, _c7_avx512); + + pC += 16; + } + } + + if 
(alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm256_store_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0 + 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_store_ps(p0 + 16, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0 + 16 + 8, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_store_ps(p0 + 16 * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0 + 16 * 2 + 8, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_store_ps(p0 + 16 * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_store_ps(p0 + 16 * 3 + 8, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_store_ps(p0 + 16 * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_store_ps(p0 + 16 * 4 + 8, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_store_ps(p0 + 16 * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_store_ps(p0 + 16 * 5 + 8, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_store_ps(p0 + 16 * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_store_ps(p0 + 16 * 6 + 8, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_store_ps(p0 + 16 * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_store_ps(p0 + 16 * 7 + 8, _mm512_extractf32x8_ps(_f7, 1)); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 8, _f4); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 2, _f6); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 3, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f3); + _mm512_storeu_ps(p0 + out_hstep * 12, _f6); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _f7); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 15, 
_mm512_extractf32x8_ps(_f7, 1)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(p0, _tmp0); + _mm512_storeu_ps(p0 + 16, _tmp1); + _mm512_storeu_ps(p0 + 32, _tmp2); + _mm512_storeu_ps(p0 + 48, _tmp3); + _mm512_storeu_ps(p0 + 64, _tmp4); + _mm512_storeu_ps(p0 + 80, _tmp5); + _mm512_storeu_ps(p0 + 96, _tmp6); + _mm512_storeu_ps(p0 + 112, _tmp7); + p0 += 128; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + p0 += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f4, _f5); + __m512 _tmp3 = _mm512_unpacklo_ps(_f6, _f7); + __m512 _tmp4 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp5 = _mm512_unpackhi_ps(_f2, _f3); + __m512 _tmp6 = _mm512_unpackhi_ps(_f4, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f7); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f2 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f5 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + _f6 = 
_mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp6), _mm512_castps_pd(_tmp7))); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + + p0 += 16; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX2__ + __m256 _f0 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)pp)); + __m256 _f1 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 8))); + __m256 _f2 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 16))); + __m256 _f3 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 24))); + __m256 _f4 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 32))); + __m256 _f5 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 40))); + __m256 _f6 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 48))); + __m256 _f7 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 56))); + pp += 64; + + // from + // 00 11 22 33 44 55 66 77 + // 01 12 23 30 45 56 67 74 + // 20 31 02 13 64 75 46 57 + // 21 32 03 10 65 76 47 54 + // 04 15 26 37 40 51 62 73 + // 05 16 27 34 41 52 63 70 + // 24 35 06 17 60 71 42 53 + // 25 36 07 14 61 72 43 50 + + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + { + __m256 _tmp0 = _f0; + __m256 _tmp1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp2 = _f2; + __m256 _tmp3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp4 = _f4; + __m256 _tmp5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp6 = _f6; + __m256 _tmp7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_unpacklo_ps(_tmp0, _tmp3); + _f1 = _mm256_unpackhi_ps(_tmp0, _tmp3); + _f2 = _mm256_unpacklo_ps(_tmp2, _tmp1); + _f3 = _mm256_unpackhi_ps(_tmp2, _tmp1); + _f4 = _mm256_unpacklo_ps(_tmp4, _tmp7); + _f5 = _mm256_unpackhi_ps(_tmp4, _tmp7); + _f6 = _mm256_unpacklo_ps(_tmp6, _tmp5); + _f7 = 
_mm256_unpackhi_ps(_tmp6, _tmp5); + + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp4 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp5 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp6 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + _tmp7 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + + _tmp1 = _mm256_shuffle_ps(_tmp1, _tmp1, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp3 = _mm256_shuffle_ps(_tmp3, _tmp3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp5 = _mm256_shuffle_ps(_tmp5, _tmp5, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp7 = _mm256_shuffle_ps(_tmp7, _tmp7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 0)); + _f1 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 0)); + _f2 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 0)); + _f3 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 0)); + _f4 = _mm256_permute2f128_ps(_tmp4, _tmp0, _MM_SHUFFLE(0, 3, 0, 0)); + _f5 = _mm256_permute2f128_ps(_tmp5, _tmp1, _MM_SHUFFLE(0, 3, 0, 0)); + _f6 = _mm256_permute2f128_ps(_tmp6, _tmp2, _MM_SHUFFLE(0, 3, 0, 0)); + _f7 = _mm256_permute2f128_ps(_tmp7, _tmp3, _MM_SHUFFLE(0, 3, 0, 0)); + } +#else // __AVX2__ + __m256 _f0 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp)); + __m256 _f1 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp + 8))); + __m256 _f2 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp + 16))); + __m256 _f3 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp + 24))); + __m256 _f4 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp1)); + __m256 _f5 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp1 + 8))); + __m256 _f6 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp1 + 16))); + __m256 _f7 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp1 + 24))); + pp += 32; + pp1 += 32; + + // from + // 00 11 22 33 04 15 26 37 + // 20 31 02 13 24 35 06 17 + // 01 12 23 30 05 16 27 34 + // 21 32 03 10 25 36 07 14 + // 40 51 62 73 44 55 66 77 + // 60 71 42 53 64 75 46 57 + // 41 52 63 70 45 56 67 74 + // 61 72 43 50 65 76 47 54 + + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + { + __m256 _tmp0 = _f0; + __m256 _tmp1 = _f1; + __m256 _tmp2 = _mm256_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp4 = _f4; + __m256 _tmp5 = _f5; + __m256 _tmp6 = _mm256_shuffle_ps(_f6, _f6, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + _f1 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _f2 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + _f3 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _f4 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + _f5 = 
_mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _f6 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + _f7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + + _tmp0 = _mm256_unpacklo_ps(_f0, _f3); + _tmp1 = _mm256_unpacklo_ps(_f1, _f2); + _tmp2 = _mm256_unpackhi_ps(_f1, _f2); + _tmp3 = _mm256_unpackhi_ps(_f0, _f3); + _tmp4 = _mm256_unpacklo_ps(_f4, _f7); + _tmp5 = _mm256_unpacklo_ps(_f5, _f6); + _tmp6 = _mm256_unpackhi_ps(_f5, _f6); + _tmp7 = _mm256_unpackhi_ps(_f4, _f7); + + _f0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp1))); + _f1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp1))); + _f2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp2), _mm256_castps_pd(_tmp3))); + _f3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp2), _mm256_castps_pd(_tmp3))); + _f4 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp4), _mm256_castps_pd(_tmp5))); + _f5 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp4), _mm256_castps_pd(_tmp5))); + _f6 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp6), _mm256_castps_pd(_tmp7))); + _f7 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp6), _mm256_castps_pd(_tmp7))); + + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + } +#endif // __AVX2__ + + _f0 = _mm256_mul_ps(_f0, _descale); + _f1 = _mm256_mul_ps(_f1, _descale); + _f2 = _mm256_mul_ps(_f2, _descale); + _f3 = _mm256_mul_ps(_f3, _descale); + _f4 = _mm256_mul_ps(_f4, _descale); + _f5 = _mm256_mul_ps(_f5, _descale); + _f6 = _mm256_mul_ps(_f6, _descale); + _f7 = _mm256_mul_ps(_f7, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c0); + _f6 = _mm256_add_ps(_f6, _c0); + _f7 = _mm256_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c0); + _f6 = _mm256_add_ps(_f6, _c0); + _f7 = _mm256_add_ps(_f7, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + __m256 _c2; + __m256 _c3; + __m256 _c4; + __m256 _c5; + __m256 _c6; + __m256 _c7; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + _c2 = _mm256_loadu_ps(pC + 16); + _c3 = _mm256_loadu_ps(pC + 24); + _c4 = _mm256_loadu_ps(pC + 32); + _c5 = _mm256_loadu_ps(pC + 40); + _c6 = _mm256_loadu_ps(pC + 48); + _c7 = _mm256_loadu_ps(pC + 56); + pC += 64; + } + else if (c_elempack == 4) + { + __m256 _tmp0 = _mm256_loadu_ps(pC); + __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + __m256 _tmp2 = _mm256_loadu_ps(pC + 16); + __m256 _tmp3 = _mm256_loadu_ps(pC + 24); + __m256 _tmp4 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _tmp5 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + __m256 _tmp6 = _mm256_loadu_ps(pC + c_hstep * 4 + 16); + __m256 _tmp7 = _mm256_loadu_ps(pC + c_hstep * 4 + 24); + _c0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + _c2 = 
_mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _c3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _c4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + _c5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + _c6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _c7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 32; + } + else // if (c_elempack == 1) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + c_hstep); + _c2 = _mm256_loadu_ps(pC + c_hstep * 2); + _c3 = _mm256_loadu_ps(pC + c_hstep * 3); + _c4 = _mm256_loadu_ps(pC + c_hstep * 4); + _c5 = _mm256_loadu_ps(pC + c_hstep * 5); + _c6 = _mm256_loadu_ps(pC + c_hstep * 6); + _c7 = _mm256_loadu_ps(pC + c_hstep * 7); + transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + _f4 = _mm256_add_ps(_f4, _c4); + _f5 = _mm256_add_ps(_f5, _c5); + _f6 = _mm256_add_ps(_f6, _c6); + _f7 = _mm256_add_ps(_f7, _c7); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm256_comp_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm256_comp_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm256_comp_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm256_comp_fmadd_ps(_c7, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + __m256 _c3 = _mm256_set1_ps(pC[3] * beta); + + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + + _c0 = _mm256_set1_ps(pC[4] * beta); + _c1 = _mm256_set1_ps(pC[5] * beta); + _c2 = _mm256_set1_ps(pC[6] * beta); + _c3 = _mm256_set1_ps(pC[7] * beta); + + _f4 = _mm256_add_ps(_f4, _c0); + _f5 = _mm256_add_ps(_f5, _c1); + _f6 = _mm256_add_ps(_f6, _c2); + _f7 = _mm256_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + _f2 = _mm256_mul_ps(_f2, _alpha); + _f3 = _mm256_mul_ps(_f3, _alpha); + _f4 = _mm256_mul_ps(_f4, _alpha); + _f5 = _mm256_mul_ps(_f5, _alpha); + _f6 = _mm256_mul_ps(_f6, _alpha); + _f7 = _mm256_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 8) + { + transpose8x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_store_ps(p0, _f0); + _mm256_store_ps(p0 + 8, _f1); + _mm256_store_ps(p0 + 16, _f2); + _mm256_store_ps(p0 + 24, _f3); + _mm256_store_ps(p0 + 32, _f4); + _mm256_store_ps(p0 + 40, _f5); + _mm256_store_ps(p0 + 48, _f6); + _mm256_store_ps(p0 + 56, _f7); + } + if (out_elempack == 4) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + transpose8x4_ps(_f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + 8, _f1); + _mm256_storeu_ps(p0 + 16, _f2); + _mm256_storeu_ps(p0 + 24, _f3); + _mm256_storeu_ps(p0 + out_hstep * 4, _f4); + _mm256_storeu_ps(p0 + out_hstep * 4 + 8, _f5); + _mm256_storeu_ps(p0 + out_hstep * 4 + 16, _f6); + _mm256_storeu_ps(p0 + out_hstep * 4 + 24, _f7); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + out_hstep, _f1); + _mm256_storeu_ps(p0 + out_hstep * 2, 
_f2); + _mm256_storeu_ps(p0 + out_hstep * 3, _f3); + _mm256_storeu_ps(p0 + out_hstep * 4, _f4); + _mm256_storeu_ps(p0 + out_hstep * 5, _f5); + _mm256_storeu_ps(p0 + out_hstep * 6, _f6); + _mm256_storeu_ps(p0 + out_hstep * 7, _f7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 8) + { + _mm256_store_ps(p0, _f0); + _mm256_store_ps(p0 + 8, _f1); + _mm256_store_ps(p0 + 16, _f2); + _mm256_store_ps(p0 + 24, _f3); + _mm256_store_ps(p0 + 32, _f4); + _mm256_store_ps(p0 + 40, _f5); + _mm256_store_ps(p0 + 48, _f6); + _mm256_store_ps(p0 + 56, _f7); + p0 += 64; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_f4, _f5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_f6, _f7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp4 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp5 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp6 = _mm256_permute2f128_ps(_f4, _f5, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp7 = _mm256_permute2f128_ps(_f6, _f7, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(p0, _tmp0); + _mm256_storeu_ps(p0 + 8, _tmp1); + _mm256_storeu_ps(p0 + 16, _tmp2); + _mm256_storeu_ps(p0 + 24, _tmp3); + _mm256_storeu_ps(p0 + out_hstep * 4, _tmp4); + _mm256_storeu_ps(p0 + out_hstep * 4 + 8, _tmp5); + _mm256_storeu_ps(p0 + out_hstep * 4 + 16, _tmp6); + _mm256_storeu_ps(p0 + out_hstep * 4 + 24, _tmp7); + p0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + out_hstep, _f1); + _mm256_storeu_ps(p0 + out_hstep * 2, _f2); + _mm256_storeu_ps(p0 + out_hstep * 3, _f3); + _mm256_storeu_ps(p0 + out_hstep * 4, _f4); + _mm256_storeu_ps(p0 + out_hstep * 5, _f5); + _mm256_storeu_ps(p0 + out_hstep * 6, _f6); + _mm256_storeu_ps(p0 + out_hstep * 7, _f7); + p0 += 8; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { +#if __AVX2__ + __m256 _f0 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)pp)); + __m256 _f1 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 8))); + __m256 _f2 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 16))); + __m256 _f3 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 24))); + pp += 32; +#else + __m256 _f01l = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp)); + __m256 _f23l = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp + 8))); + __m256 _f01h = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp1)); + __m256 _f23h = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)(pp1 + 8))); + __m256 _f0 = _mm256_permute2f128_ps(_f01l, _f01h, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _f1 = _mm256_permute2f128_ps(_f01l, _f01h, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _f2 = _mm256_permute2f128_ps(_f23l, _f23h, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _f3 = _mm256_permute2f128_ps(_f23l, _f23h, _MM_SHUFFLE(0, 3, 0, 1)); + pp += 16; + pp1 += 16; +#endif + + // from + // 00 11 22 33 40 51 62 73 + // 01 12 23 30 41 52 63 70 + // 20 31 02 13 60 71 42 53 + // 21 32 03 10 61 72 43 50 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + { + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp0 = _mm256_unpacklo_ps(_f0, _f3); 
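+                // the int8 kernel emits these four accumulator rows in a rotated diagonal
+                // order (see the from/to map above); rotating _f1/_f3 by one element,
+                // interleaving at 32-bit and then 64-bit granularity, and rotating back
+                // restores the plain column-major 00 10 20 30 .. layout in each register.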
+ __m256 _tmp1 = _mm256_unpackhi_ps(_f0, _f3); + __m256 _tmp2 = _mm256_unpacklo_ps(_f2, _f1); + __m256 _tmp3 = _mm256_unpackhi_ps(_f2, _f1); + _f0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp2))); + _f1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp2))); + _f2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp3), _mm256_castps_pd(_tmp1))); + _f3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp3), _mm256_castps_pd(_tmp1))); + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm256_mul_ps(_f0, _descale); + _f1 = _mm256_mul_ps(_f1, _descale); + _f2 = _mm256_mul_ps(_f2, _descale); + _f3 = _mm256_mul_ps(_f3, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + _f2 = _mm256_add_ps(_f2, _c0); + _f3 = _mm256_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + __m256 _c2; + __m256 _c3; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + _c2 = _mm256_loadu_ps(pC + 16); + _c3 = _mm256_loadu_ps(pC + 24); + pC += 32; + } + else if (c_elempack == 4) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + 8); + __m256 _cc2 = _mm256_loadu_ps(pC + c_hstep * 4); + __m256 _cc3 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + _c0 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 3, 0, 1)); + _c2 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 2, 0, 0)); + _c3 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 16; + } + else // if (c_elempack == 1) + { + // __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + // _c0 = _mm256_i32gather_ps(pC, _vindex, c_hstep * sizeof(float)); + // _c1 = _mm256_i32gather_ps(pC + 1, _vindex, c_hstep * sizeof(float)); + // _c2 = _mm256_i32gather_ps(pC + 2, _vindex, c_hstep * sizeof(float)); + // _c3 = _mm256_i32gather_ps(pC + 3, _vindex, c_hstep * sizeof(float)); + + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep); + __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 2); + __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 3); + __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 4); + __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 5); + __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 6); + __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 7); + _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3); + _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7); + + _c0 = combine4x2_ps(_cc0, _cc4); + _c1 = combine4x2_ps(_cc1, _cc5); + _c2 = combine4x2_ps(_cc2, _cc6); + _c3 = combine4x2_ps(_cc3, _cc7); + + pC += 4; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + __m256 _c3 = _mm256_set1_ps(pC[3] 
* beta); + + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + _f2 = _mm256_add_ps(_f2, _c2); + _f3 = _mm256_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + _f2 = _mm256_mul_ps(_f2, _alpha); + _f3 = _mm256_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + const int jj_m16 = jj % 16; + float* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_store_ps(p1, _mm256_extractf128_ps(_f0, 0)); + _mm_store_ps(p1 + 16, _mm256_extractf128_ps(_f0, 1)); + _mm_store_ps(p1 + 32, _mm256_extractf128_ps(_f1, 0)); + _mm_store_ps(p1 + 48, _mm256_extractf128_ps(_f1, 1)); + _mm_store_ps(p1 + 64, _mm256_extractf128_ps(_f2, 0)); + _mm_store_ps(p1 + 80, _mm256_extractf128_ps(_f2, 1)); + _mm_store_ps(p1 + 96, _mm256_extractf128_ps(_f3, 0)); + _mm_store_ps(p1 + 112, _mm256_extractf128_ps(_f3, 1)); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_store_ps(p1, _mm256_extractf128_ps(_f0, 0)); + _mm_store_ps(p1 + 8, _mm256_extractf128_ps(_f0, 1)); + _mm_store_ps(p1 + 16, _mm256_extractf128_ps(_f1, 0)); + _mm_store_ps(p1 + 24, _mm256_extractf128_ps(_f1, 1)); + _mm_store_ps(p1 + 32, _mm256_extractf128_ps(_f2, 0)); + _mm_store_ps(p1 + 40, _mm256_extractf128_ps(_f2, 1)); + _mm_store_ps(p1 + 48, _mm256_extractf128_ps(_f3, 0)); + _mm_store_ps(p1 + 56, _mm256_extractf128_ps(_f3, 1)); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + 8, _f1); + _mm256_storeu_ps(p0 + 16, _f2); + _mm256_storeu_ps(p0 + 24, _f3); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + out_hstep, _f1); + _mm256_storeu_ps(p0 + out_hstep * 2, _f2); + _mm256_storeu_ps(p0 + out_hstep * 3, _f3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 8) + { + _mm256_store_ps(p0, _f0); + _mm256_store_ps(p0 + 8, _f1); + _mm256_store_ps(p0 + 16, _f2); + _mm256_store_ps(p0 + 24, _f3); + p0 += 32; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_f2, _f3, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(p0, _tmp0); + _mm256_storeu_ps(p0 + 8, _tmp1); + _mm256_storeu_ps(p0 + out_hstep * 4, _tmp2); + _mm256_storeu_ps(p0 + out_hstep * 4 + 8, _tmp3); + p0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_ps(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0, _mm256_extractf128_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm256_extractf128_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm256_extractf128_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm256_extractf128_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm256_extractf128_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 5, _mm256_extractf128_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 6, _mm256_extractf128_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep * 7, _mm256_extractf128_ps(_f3, 1)); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { +#if __AVX2__ + 
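+            // AVX2: the two interleaved 8-wide int32 accumulator rows sit contiguously in pp;
+            // the non-AVX2 path below instead reads the low/high 128-bit halves from the split
+            // pp/pp1 streams and stitches them back together with _mm256_permute2f128_ps.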
__m256 _f0 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)pp)); + __m256 _f1 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)(pp + 8))); + pp += 16; +#else + __m256 _f01l = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp)); + __m256 _f01h = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp1)); + __m256 _f0 = _mm256_permute2f128_ps(_f01l, _f01h, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _f1 = _mm256_permute2f128_ps(_f01l, _f01h, _MM_SHUFFLE(0, 3, 0, 1)); + pp += 8; + pp1 += 8; +#endif + + // from + // 00 11 20 31 40 51 60 71 + // 01 10 21 30 41 50 61 70 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + { + __m256 _tmp0 = _mm256_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256 _tmp1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm256_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm256_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm256_mul_ps(_f0, _descale); + _f1 = _mm256_mul_ps(_f1, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m256 _c1; + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + _c1 = _mm256_loadu_ps(pC + 8); + pC += 16; + } + else if (c_elempack == 4) + { + __m256 _cc0 = _mm256_loadu_ps(pC); + __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 4); + _c0 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 2, 0, 0)); + _c1 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 3, 0, 1)); + pC += 8; + } + else // if (c_elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(c_hstep)); + _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); + _c1 = _mm256_i32gather_ps(pC + 1, _vindex, sizeof(float)); +#else + _c0 = _mm256_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3], pC[c_hstep * 4], pC[c_hstep * 5], pC[c_hstep * 6], pC[c_hstep * 7]); + _c1 = _mm256_setr_ps(pC[1], pC[c_hstep + 1], pC[c_hstep * 2 + 1], pC[c_hstep * 3 + 1], pC[c_hstep * 4 + 1], pC[c_hstep * 5 + 1], pC[c_hstep * 6 + 1], pC[c_hstep * 7 + 1]); +#endif + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + } + else + { + __m256 _beta = _mm256_set1_ps(beta); + _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + _f0 = _mm256_add_ps(_f0, _c0); + _f1 = _mm256_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m256 _alpha = _mm256_set1_ps(alpha); + _f0 = _mm256_mul_ps(_f0, _alpha); + _f1 = _mm256_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + out_hstep, _f1); + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _f0); + _mm256_storeu_ps(p0 + 8, _f1); + p0 += 16; + } + if (out_elempack == 4) + { + __m256 _tmp0 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_f0, _f1, _MM_SHUFFLE(0, 3, 0, 1)); + + _mm256_storeu_ps(p0, _tmp0); + _mm256_storeu_ps(p0 + out_hstep * 4, _tmp1); + p0 += 8; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 
6, 7), _mm256_set1_epi32(out_hstep)); + _mm256_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm256_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); +#else + float sum0[8]; + float sum1[8]; + _mm256_storeu_ps(sum0, _f0); + _mm256_storeu_ps(sum1, _f1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 4 + 1] = sum1[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 5 + 1] = sum1[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 6 + 1] = sum1[6]; + p0[out_hstep * 7] = sum0[7]; + p0[out_hstep * 7 + 1] = sum1[7]; +#endif // __AVX512F__ + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { +#if __AVX2__ + __m256 _f0 = _mm256_cvtepi32_ps(_mm256_load_si256((const __m256i*)pp)); + pp += 8; +#else + __m128i _f0l = _mm_load_si128((const __m128i*)pp); + __m128i _f0h = _mm_load_si128((const __m128i*)pp1); + __m256 _f0 = _mm256_cvtepi32_ps(combine4x2_epi32(_f0l, _f0h)); + pp += 4; + pp1 += 4; +#endif + + _f0 = _mm256_mul_ps(_f0, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm256_add_ps(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm256_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 8) + { + _c0 = _mm256_loadu_ps(pC); + pC += 8; + } + else if (c_elempack == 4) + { + __m128 _cc0 = _mm_loadu_ps(pC); + __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4); + _c0 = combine4x2_ps(_cc0, _cc1); + pC += 4; + } + else // if (c_elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(c_hstep)); + _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); +#else + _c0 = _mm256_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3], pC[c_hstep * 4], pC[c_hstep * 5], pC[c_hstep * 6], pC[c_hstep * 7]); +#endif + pC += 1; + } + _f0 = _mm256_comp_fmadd_ps(_c0, _mm256_set1_ps(beta), _f0); + } + if (broadcast_type_C == 4) + { + _c0 = _mm256_set1_ps(pC[0] * beta); + _f0 = _mm256_add_ps(_f0, _c0); + pC += 1; + } + } + + _f0 = _mm256_mul_ps(_f0, _mm256_set1_ps(alpha)); + + if (output_transpose) + { + _mm256_storeu_ps(p0, _f0); + p0 += out_hstep; + } + else + { + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _f0); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_store_ps(p0, _mm256_extractf128_ps(_f0, 0)); + _mm_store_ps(p0 + out_hstep * 4, _mm256_extractf128_ps(_f0, 1)); + p0 += 4; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(out_hstep)); + _mm256_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); +#else + float sum0[8]; + _mm256_storeu_ps(sum0, _f0); + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum0[4]; + p0[out_hstep * 5] = sum0[5]; + p0[out_hstep * 6] = sum0[6]; + p0[out_hstep * 7] = sum0[7]; +#endif // __AVX512F__ + p0++; + } + } + } + +#if !__AVX2__ + pp = pp1; + pp1 = pp + max_jj * 4; +#endif + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) + { + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m128 _descale = _mm_load_ps((const float*)descales + i + ii); +#if __AVX512F__ + __m512 
_descale_avx512 = _mm512_broadcast_f32x4(_descale); +#endif + + __m128 _c0 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(pC[0] * beta); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm_loadu_ps(pC); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x4(_c0); +#endif + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 48))); + + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 3a 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp2 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp2))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale_avx512); + _f1 = _mm512_mul_ps(_f1, _descale_avx512); + _f2 = _mm512_mul_ps(_f2, _descale_avx512); + _f3 = _mm512_mul_ps(_f3, _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 3) + { + __m512 _c1_avx512; + __m512 _c2_avx512; + __m512 _c3_avx512; + if (c_elempack == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + 16); + _c2_avx512 = _mm512_loadu_ps(pC + 32); + _c3_avx512 = _mm512_loadu_ps(pC + 48); + pC += 64; + + __m512 _tmp0 = _mm512_shuffle_f32x4(_c0_avx512, _c1_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_c0_avx512, 
_c1_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_c2_avx512, _c3_avx512, _MM_SHUFFLE(3, 2, 3, 2)); + _c0_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _c1_avx512 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _c2_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _c3_avx512 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + else // if (c_elempack == 1) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + _c2_avx512 = _mm512_loadu_ps(pC + c_hstep * 2); + _c3_avx512 = _mm512_loadu_ps(pC + c_hstep * 3); + pC += 16; + + __m512 _tmp0 = _mm512_unpacklo_ps(_c0_avx512, _c1_avx512); + __m512 _tmp1 = _mm512_unpacklo_ps(_c2_avx512, _c3_avx512); + __m512 _tmp2 = _mm512_unpackhi_ps(_c0_avx512, _c1_avx512); + __m512 _tmp3 = _mm512_unpackhi_ps(_c2_avx512, _c3_avx512); + _c0_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c1_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _c2_avx512 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _c3_avx512 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2_avx512, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3_avx512, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + __m512 _cc = _mm512_loadu_ps(pC); + _cc = _mm512_mul_ps(_cc, _mm512_set1_ps(beta)); + _c0_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(0, 0, 0, 0)); + __m512 _c1_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(1, 1, 1, 1)); + __m512 _c2_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(2, 2, 2, 2)); + __m512 _c3_avx512 = _mm512_permute_ps(_cc, _MM_SHUFFLE(3, 3, 3, 3)); + + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + _f2 = _mm512_add_ps(_f2, _c2_avx512); + _f3 = _mm512_add_ps(_f3, _c3_avx512); + + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_store_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0 + 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_store_ps(p0 + 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_store_ps(p0 + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0 + 16 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0 + 16 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_store_ps(p0 + 16 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_store_ps(p0 + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0 + 32 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0 + 32 + 8, _mm512_extractf32x4_ps(_f2, 2)); + _mm_store_ps(p0 + 32 + 12, _mm512_extractf32x4_ps(_f3, 2)); + _mm_store_ps(p0 + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0 + 48 + 4, _mm512_extractf32x4_ps(_f1, 3)); + _mm_store_ps(p0 + 48 + 8, _mm512_extractf32x4_ps(_f2, 3)); + 
_mm_store_ps(p0 + 48 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 8) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_store_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0 + 8, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0 + 12, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0 + 16, _mm512_extractf32x4_ps(_f0, 2)); + _mm_store_ps(p0 + 16 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0 + 16 + 8, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0 + 16 + 12, _mm512_extractf32x4_ps(_f1, 3)); + _mm_store_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_store_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f3, 0)); + _mm_store_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_store_ps(p0 + out_hstep * 8 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_store_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x4_ps(_f2, 2)); + _mm_store_ps(p0 + out_hstep * 8 + 16 + 4, _mm512_extractf32x4_ps(_f3, 2)); + _mm_store_ps(p0 + out_hstep * 8 + 16 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_store_ps(p0 + out_hstep * 8 + 16 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep * 4, _f1); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 12, _f3); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + p0 += 64; + } + if (out_elempack == 1) + { + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpacklo_ps(_f2, _f3); + __m512 _tmp2 = _mm512_unpackhi_ps(_f0, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, 
_f3); + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp2), _mm512_castps_pd(_tmp3))); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + p0 += 16; + } + } + + pp += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4))); + __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8))); + __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12))); + __m128 _f4 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 16))); + __m128 _f5 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 20))); + __m128 _f6 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 24))); + __m128 _f7 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 28))); + + // from + // 00 11 22 33 + // 04 15 26 37 + // 20 31 02 13 + // 24 35 06 17 + // 01 12 23 30 + // 05 16 27 34 + // 21 32 03 10 + // 25 36 07 14 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + // 04 14 24 34 + // 05 15 25 35 + // 06 16 26 36 + // 07 17 27 37 + { + _f4 = _mm_shuffle_ps(_f4, _f4, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f6 = _mm_shuffle_ps(_f6, _f6, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f6); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f6); + __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f7); + __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f7); + __m128 _tmp4 = _mm_unpacklo_ps(_f2, _f4); + __m128 _tmp5 = _mm_unpackhi_ps(_f2, _f4); + __m128 _tmp6 = _mm_unpacklo_ps(_f3, _f5); + __m128 _tmp7 = _mm_unpackhi_ps(_f3, _f5); + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp4))); + _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp5), _mm_castps_pd(_tmp1))); + _f4 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6))); + _f5 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp6))); + _f6 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3))); + _f7 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp7), _mm_castps_pd(_tmp3))); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm_mul_ps(_f0, _descale); + _f1 = _mm_mul_ps(_f1, _descale); + _f2 = _mm_mul_ps(_f2, _descale); + _f3 = _mm_mul_ps(_f3, _descale); + _f4 = _mm_mul_ps(_f4, _descale); + _f5 = _mm_mul_ps(_f5, _descale); + _f6 = _mm_mul_ps(_f6, _descale); + _f7 = _mm_mul_ps(_f7, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c0); + 
_f6 = _mm_add_ps(_f6, _c0); + _f7 = _mm_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c0); + _f6 = _mm_add_ps(_f6, _c0); + _f7 = _mm_add_ps(_f7, _c0); + } + if (broadcast_type_C == 3) + { + __m128 _c1; + __m128 _c2; + __m128 _c3; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c2 = _mm_loadu_ps(pC + 8); + _c3 = _mm_loadu_ps(pC + 12); + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + _c2 = _mm_loadu_ps(pC + c_hstep * 2); + _c3 = _mm_loadu_ps(pC + c_hstep * 3); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC + 16); + _c1 = _mm_loadu_ps(pC + 20); + _c2 = _mm_loadu_ps(pC + 24); + _c3 = _mm_loadu_ps(pC + 28); + pC += 32; + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC + 4); + _c1 = _mm_loadu_ps(pC + c_hstep + 4); + _c2 = _mm_loadu_ps(pC + c_hstep * 2 + 4); + _c3 = _mm_loadu_ps(pC + c_hstep * 3 + 4); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + pC += 8; + } + if (beta == 1.f) + { + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c1); + _f6 = _mm_add_ps(_f6, _c2); + _f7 = _mm_add_ps(_f7, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f4 = _mm_comp_fmadd_ps(_c0, _beta, _f4); + _f5 = _mm_comp_fmadd_ps(_c1, _beta, _f5); + _f6 = _mm_comp_fmadd_ps(_c2, _beta, _f6); + _f7 = _mm_comp_fmadd_ps(_c3, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + __m128 _c2 = _mm_set1_ps(pC[2] * beta); + __m128 _c3 = _mm_set1_ps(pC[3] * beta); + + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + + _c0 = _mm_set1_ps(pC[4] * beta); + _c1 = _mm_set1_ps(pC[5] * beta); + _c2 = _mm_set1_ps(pC[6] * beta); + _c3 = _mm_set1_ps(pC[7] * beta); + + _f4 = _mm_add_ps(_f4, _c0); + _f5 = _mm_add_ps(_f5, _c1); + _f6 = _mm_add_ps(_f6, _c2); + _f7 = _mm_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + _f4 = _mm_mul_ps(_f4, _alpha); + _f5 = _mm_mul_ps(_f5, _alpha); + _f6 = _mm_mul_ps(_f6, _alpha); + _f7 = _mm_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { +#if __AVX__ + if (out_elempack == 8) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _MM_TRANSPOSE4_PS(_f4, _f5, _f6, _f7); + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f4); + _mm_store_ps(p0 + 8, _f1); + _mm_store_ps(p0 + 12, _f5); + _mm_store_ps(p0 + 16, _f2); + _mm_store_ps(p0 + 20, _f6); + _mm_store_ps(p0 + 24, _f3); + _mm_store_ps(p0 + 28, _f7); + } +#endif // __AVX__ + if (out_elempack == 4) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _MM_TRANSPOSE4_PS(_f4, _f5, _f6, _f7); + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + _mm_store_ps(p0 + 8, _f2); + _mm_store_ps(p0 + 12, _f3); + 
_mm_store_ps(p0 + out_hstep * 4, _f4); + _mm_store_ps(p0 + out_hstep * 4 + 4, _f5); + _mm_store_ps(p0 + out_hstep * 4 + 8, _f6); + _mm_store_ps(p0 + out_hstep * 4 + 12, _f7); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep, _f1); + _mm_storeu_ps(p0 + out_hstep * 2, _f2); + _mm_storeu_ps(p0 + out_hstep * 3, _f3); + _mm_storeu_ps(p0 + out_hstep * 4, _f4); + _mm_storeu_ps(p0 + out_hstep * 5, _f5); + _mm_storeu_ps(p0 + out_hstep * 6, _f6); + _mm_storeu_ps(p0 + out_hstep * 7, _f7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + _mm_store_ps(p0 + 8, _f2); + _mm_store_ps(p0 + 12, _f3); + _mm_store_ps(p0 + 16, _f4); + _mm_store_ps(p0 + 20, _f5); + _mm_store_ps(p0 + 24, _f6); + _mm_store_ps(p0 + 28, _f7); + p0 += 32; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _MM_TRANSPOSE4_PS(_f4, _f5, _f6, _f7); + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + 4, _f4); + _mm_storeu_ps(p0 + out_hstep, _f1); + _mm_storeu_ps(p0 + out_hstep + 4, _f5); + _mm_storeu_ps(p0 + out_hstep * 2, _f2); + _mm_storeu_ps(p0 + out_hstep * 2 + 4, _f6); + _mm_storeu_ps(p0 + out_hstep * 3, _f3); + _mm_storeu_ps(p0 + out_hstep * 3 + 4, _f7); + p0 += 8; + } + } + + pp += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4))); + __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8))); + __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12))); + + // from + // 00 11 22 33 + // 01 12 23 30 + // 20 31 02 13 + // 21 32 03 10 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + { + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f3); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f3); + __m128 _tmp2 = _mm_unpacklo_ps(_f2, _f1); + __m128 _tmp3 = _mm_unpackhi_ps(_f2, _f1); + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp2))); + _f2 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp3), _mm_castps_pd(_tmp1))); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm_mul_ps(_f0, _descale); + _f1 = _mm_mul_ps(_f1, _descale); + _f2 = _mm_mul_ps(_f2, _descale); + _f3 = _mm_mul_ps(_f3, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m128 _c1; + __m128 _c2; + __m128 _c3; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c2 = _mm_loadu_ps(pC + 8); + _c3 = _mm_loadu_ps(pC + 12); + pC += 16; + } + else // if (c_elempack == 1) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + _c2 = _mm_loadu_ps(pC + c_hstep * 2); + _c3 = _mm_loadu_ps(pC + c_hstep * 3); + _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + pC += 
4; + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + __m128 _c2 = _mm_set1_ps(pC[2] * beta); + __m128 _c3 = _mm_set1_ps(pC[3] * beta); + + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + const int jj_m16 = jj % 16; + float* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_store_ps(p1, _f0); + _mm_store_ps(p1 + 16, _f1); + _mm_store_ps(p1 + 32, _f2); + _mm_store_ps(p1 + 48, _f3); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + const int jj_m8 = jj % 8; + float* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_store_ps(p1, _f0); + _mm_store_ps(p1 + 8, _f1); + _mm_store_ps(p1 + 16, _f2); + _mm_store_ps(p1 + 24, _f3); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + _mm_store_ps(p0 + 8, _f2); + _mm_store_ps(p0 + 12, _f3); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep, _f1); + _mm_storeu_ps(p0 + out_hstep * 2, _f2); + _mm_storeu_ps(p0 + out_hstep * 3, _f3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + _mm_store_ps(p0 + 8, _f2); + _mm_store_ps(p0 + 12, _f3); + p0 += 16; + } + if (out_elempack == 1) + { + _MM_TRANSPOSE4_PS(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep, _f1); + _mm_storeu_ps(p0 + out_hstep * 2, _f2); + _mm_storeu_ps(p0 + out_hstep * 3, _f3); + p0 += 4; + } + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4))); + + // from + // 00 11 20 31 + // 01 10 21 30 + // to + // 00 10 20 30 + // 01 11 21 31 + { + __m128 _tmp0 = _mm_shuffle_ps(_f0, _f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128 _tmp1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm_mul_ps(_f0, _descale); + _f1 = _mm_mul_ps(_f1, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m128 _c1; + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + pC += 8; + } + else // if (c_elempack == 1) + { + _c0 = _mm_setr_ps(pC[0], pC[c_hstep], 
pC[c_hstep * 2], pC[c_hstep * 3]); + _c1 = _mm_setr_ps(pC[1], pC[c_hstep + 1], pC[c_hstep * 2 + 1], pC[c_hstep * 3 + 1]); + pC += 2; + } + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + __m128 _c1 = _mm_set1_ps(pC[1] * beta); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep, _f1); + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + p0 += 8; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_mullo_epi32(_mm_setr_epi32(0, 1, 2, 3), _mm_set1_epi32(out_hstep)); + _mm_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); +#else + float sum0[4]; + float sum1[4]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum1, _f1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; +#endif // __AVX512F__ + p0 += 2; + } + } + + pp += 8; + } + for (; jj < max_jj; jj++) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + + _f0 = _mm_mul_ps(_f0, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) + { + _c0 = _mm_loadu_ps(pC); + pC += 4; + } + else // if (c_elempack == 1) + { + _c0 = _mm_setr_ps(pC[0], pC[c_hstep], pC[c_hstep * 2], pC[c_hstep * 3]); + pC += 1; + } + _f0 = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f0); + } + if (broadcast_type_C == 4) + { + _c0 = _mm_set1_ps(pC[0] * beta); + _f0 = _mm_add_ps(_f0, _c0); + pC += 1; + } + } + + _f0 = _mm_mul_ps(_f0, _mm_set1_ps(alpha)); + + if (output_transpose) + { + _mm_storeu_ps(p0, _f0); + p0 += out_hstep; + } + else + { + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + p0 += 4; + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_mullo_epi32(_mm_setr_epi32(0, 1, 2, 3), _mm_set1_epi32(out_hstep)); + _mm_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); +#else + float sum0[4]; + _mm_storeu_ps(sum0, _f0); + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; +#endif // __AVX512F__ + p0++; + } + } + + pp += 4; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + // out_elempack == 1 + p0 = (float*)top_blob + (i + ii) * out_hstep + j; + } + + const float descale0 = descales[i + ii]; + const float descale1 = descales[i + ii + 1]; +#if __SSE2__ + __m128 _descale0 = _mm_set1_ps(descale0); + __m128 _descale1 = _mm_set1_ps(descale1); +#if __AVX512F__ + __m512 _descale0_avx512 = _mm512_set1_ps(descale0); + __m512 _descale1_avx512 = _mm512_set1_ps(descale1); +#endif // __AVX512F__ +#endif + + float c0 = 0.f; + float c1 = 0.f; +#if __SSE2__ + 
__m128 _c0 = _mm_set1_ps(0.f); + __m128 _c1 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); + __m512 _c1_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); + _c1 = _mm_set1_ps(c1); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); + _c1_avx512 = _mm512_set1_ps(c1); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 16))); + + // 00 11 02 13 04 15 06 17 08 19 0a 1b 0c 1d 0e 1f + // 01 12 03 10 05 16 07 14 09 1a 0b 18 0d 1e 0f 1c + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f1); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f1); + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp1))); + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm512_mul_ps(_f0, _descale0_avx512); + _f1 = _mm512_mul_ps(_f1, _descale1_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + } + pC += 16; + } + if (broadcast_type_C == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c0_avx512 = _mm512_mul_ps(_c0_avx512, _mm512_set1_ps(beta)); + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + } + if (out_elempack == 8) + { + _mm256_store_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_store_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_store_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_store_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + } + if (out_elempack == 4) + { + _mm_store_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_store_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_store_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_store_ps(p0 + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_store_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + 
_mm_store_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_store_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_store_ps(p0 + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_f1, 3)); + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); + } + p0 += out_hstep * 16; + } + else + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + p0 += 16; + } + + pp += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4))); + __m128 _f2 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 8))); + __m128 _f3 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 12))); + + // 00 11 02 13 + // 04 15 06 17 + // 10 01 12 03 + // 14 05 16 07 + _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1)); + + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f2); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f2); + __m128 _tmp2 = _mm_unpacklo_ps(_f1, _f3); + __m128 _tmp3 = _mm_unpackhi_ps(_f1, _f3); + + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f1 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3))); + _f2 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f3 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp2), _mm_castps_pd(_tmp3))); + + _f2 = _mm_shuffle_ps(_f2, _f2, _MM_SHUFFLE(2, 3, 0, 1)); + _f3 = _mm_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 3, 0, 1)); + + _f0 = _mm_mul_ps(_f0, _descale0); + _f1 = _mm_mul_ps(_f1, _descale0); + _f2 = _mm_mul_ps(_f2, _descale1); + _f3 = _mm_mul_ps(_f3, _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + _f2 = _mm_add_ps(_f2, _c1); + _f3 = _mm_add_ps(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + __m128 _c2 = _mm_loadu_ps(pC + c_hstep); + __m128 _c3 = _mm_loadu_ps(pC + c_hstep + 4); + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c2); + _f3 = _mm_add_ps(_f3, _c3); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + } + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + 4); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); + _c1 = _mm_mul_ps(_c1, _mm_set1_ps(beta)); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + _f2 = _mm_add_ps(_f2, _c0); + _f3 = _mm_add_ps(_f3, _c1); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + _f2 = _mm_mul_ps(_f2, _alpha); + _f3 = _mm_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { +#if __AVX__ + if (out_elempack == 8) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 
4, _f1); + _mm_store_ps(p0 + 8, _f2); + _mm_store_ps(p0 + 12, _f3); + } +#endif // __AVX__ + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f2); + _mm_store_ps(p0 + out_hstep * 4, _f1); + _mm_store_ps(p0 + out_hstep * 4 + 4, _f3); + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + float sum2[4]; + float sum3[4]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum1, _f1); + _mm_storeu_ps(sum2, _f2); + _mm_storeu_ps(sum3, _f3); + + p0[0] = sum0[0]; + p0[1] = sum2[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep + 1] = sum2[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum2[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum2[3]; + p0[out_hstep * 4] = sum1[0]; + p0[out_hstep * 4 + 1] = sum3[0]; + p0[out_hstep * 5] = sum1[1]; + p0[out_hstep * 5 + 1] = sum3[1]; + p0[out_hstep * 6] = sum1[2]; + p0[out_hstep * 6 + 1] = sum3[2]; + p0[out_hstep * 7] = sum1[3]; + p0[out_hstep * 7 + 1] = sum3[3]; + } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + 4, _f1); + _mm_storeu_ps(p0 + out_hstep, _f2); + _mm_storeu_ps(p0 + out_hstep + 4, _f3); + p0 += 8; + } + + pp += 16; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _f0 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)pp)); + __m128 _f1 = _mm_cvtepi32_ps(_mm_load_si128((const __m128i*)(pp + 4))); + + // 00 11 02 13 + // 01 12 03 10 + __m128 _tmp0 = _mm_unpacklo_ps(_f0, _f1); + __m128 _tmp1 = _mm_unpackhi_ps(_f0, _f1); + + _f0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_tmp0), _mm_castps_pd(_tmp1))); + _f1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_tmp1), _mm_castps_pd(_tmp0))); + + _f1 = _mm_shuffle_ps(_f1, _f1, _MM_SHUFFLE(0, 3, 2, 1)); + + _f0 = _mm_mul_ps(_f0, _descale0); + _f1 = _mm_mul_ps(_f1, _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _c1 = _mm_loadu_ps(pC + c_hstep); + if (beta == 1.f) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c1); + } + else + { + __m128 _beta = _mm_set1_ps(beta); + _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + } + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = _mm_loadu_ps(pC); + _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + pC += 4; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + const int jj_m16 = jj % 16; + float* p1 = p0 - out_hstep * jj_m16 + jj_m16; + _mm_store_ps(p1, _f0); + _mm_store_ps(p1 + 16, _f1); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + const int jj_m8 = jj % 8; + float* p1 = p0 - out_hstep * jj_m8 + jj_m8; + _mm_store_ps(p1, _f0); + _mm_store_ps(p1 + 8, _f1); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _mm_store_ps(p0, _f0); + _mm_store_ps(p0 + 4, _f1); + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum1, _f1); + + p0[0] = sum0[0]; + p0[1] = sum1[0]; + p0[out_hstep] = sum0[1]; + 
p0[out_hstep + 1] = sum1[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 2 + 1] = sum1[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 3 + 1] = sum1[3]; + } + p0 += out_hstep * 4; + } + else + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep, _f1); + p0 += 4; + } + + pp += 8; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float f00 = pp[0] * descale0; + float f01 = pp[1] * descale0; + float f10 = pp[2] * descale1; + float f11 = pp[3] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += c0; + f10 += c1; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += pC[0] * beta; + f01 += pC[1] * beta; + f10 += pC[c_hstep] * beta; + f11 += pC[c_hstep + 1] * beta; + pC += 2; + } + if (broadcast_type_C == 4) + { + f00 += pC[0] * beta; + f01 += pC[1] * beta; + f10 += pC[0] * beta; + f11 += pC[1] * beta; + pC += 2; + } + } + + f00 *= alpha; + f01 *= alpha; + f10 *= alpha; + f11 *= alpha; + + if (output_transpose) + { + p0[0] = f00; + p0[1] = f10; + p0[out_hstep] = f01; + p0[out_hstep + 1] = f11; + p0 += out_hstep * 2; + } + else + { + p0[0] = f00; + p0[1] = f01; + p0[out_hstep] = f10; + p0[out_hstep + 1] = f11; + p0 += 2; + } + + pp += 4; + } + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + f1 *= alpha; + + if (output_transpose) + { + p0[0] = f0; + p0[1] = f1; + p0 += out_hstep; + } + else + { + p0[0] = f0; + p0[out_hstep] = f1; + p0++; + } + + pp += 2; + } + } + for (; ii < max_ii; ii += 1) + { + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + // out_elempack == 1 + p0 = (float*)top_blob + (i + ii) * out_hstep + j; + } + + const float descale = descales[i + ii]; +#if __SSE2__ + __m128 _descale = _mm_set1_ps(descale); +#if __AVX512F__ + __m512 _descale_avx512 = _mm512_set1_ps(descale); +#endif // __AVX512F__ +#endif + + float c0 = 0.f; +#if __SSE2__ + __m128 _c0 = _mm_set1_ps(0.f); +#if __AVX512F__ + __m512 _c0_avx512 = _mm512_set1_ps(0.f); +#endif // __AVX512F__ +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; +#if __SSE2__ + _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp)), _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) 
+ { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _f0 = _mm512_fmadd_ps(_c0_avx512, _mm512_set1_ps(beta), _f0); + pC += 16; + } + } + + if (alpha != 1.f) + { + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + } + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm512_storeu_ps(p0, _f0); + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + } + } + p0 += out_hstep * 16; + } + else + { + _mm512_storeu_ps(p0, _f0); + p0 += 16; + } + + pp += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)pp)), _descale); + __m128 _f1 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(pp + 4))), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + _f1 = _mm_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = _mm_loadu_ps(pC); + __m128 _c1 = _mm_loadu_ps(pC + 4); + _f0 = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f0); + _f1 = _mm_comp_fmadd_ps(_c1, _mm_set1_ps(beta), _f1); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m128 _alpha = _mm_set1_ps(alpha); + _f0 = _mm_mul_ps(_f0, _alpha); + _f1 = _mm_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + 4, _f1); + } + else + { +#if __AVX__ + if (out_elempack == 8) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + 4, _f1); + } +#endif // __AVX__ + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + out_hstep * 4, _f1); + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + _mm_storeu_ps(sum0, _f0); + _mm_storeu_ps(sum1, _f1); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + p0[out_hstep * 4] = sum1[0]; + p0[out_hstep * 5] = sum1[1]; + p0[out_hstep * 6] = sum1[2]; + p0[out_hstep * 7] = sum1[3]; + } + } + p0 += out_hstep * 8; + } + else + { + _mm_storeu_ps(p0, _f0); + _mm_storeu_ps(p0 + 4, _f1); + p0 += 8; + } + + pp += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = _mm_loadu_ps(pC); + _f0 = _mm_comp_fmadd_ps(_c0, _mm_set1_ps(beta), _f0); + pC += 4; + } + } + + _f0 = _mm_mul_ps(_f0, _mm_set1_ps(alpha)); + + if (output_transpose) + { + if (out_hstep == 1) + { + 
_mm_storeu_ps(p0, _f0); + } + else + { +#if !(defined(__x86_64__) || defined(_M_X64)) +#if __AVX__ +#if __AVX512F__ + if (out_elempack == 16) + { + _mm_storeu_ps(p0 - (jj % 16) / 4 * out_hstep * 4 + (jj % 16) / 4 * 4, _f0); + } +#endif // __AVX512F__ + if (out_elempack == 8) + { + _mm_storeu_ps(p0 - (jj % 8) / 4 * out_hstep * 4 + (jj % 8) / 4 * 4, _f0); + } +#endif // __AVX__ +#endif // !(defined(__x86_64__) || defined(_M_X64)) + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _f0); + } + if (out_elempack == 1) + { + float sum0[4]; + _mm_storeu_ps(sum0, _f0); + + p0[0] = sum0[0]; + p0[out_hstep] = sum0[1]; + p0[out_hstep * 2] = sum0[2]; + p0[out_hstep * 3] = sum0[3]; + } + } + p0 += out_hstep * 4; + } + else + { + _mm_storeu_ps(p0, _f0); + p0 += 4; + } + + pp += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + float f0 = pp[0] * descale; + float f1 = pp[1] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[1] * beta; + pC += 2; + } + } + + f0 *= alpha; + f1 *= alpha; + + if (output_transpose) + { + p0[0] = f0; + p0[out_hstep] = f1; + p0 += out_hstep * 2; + } + else + { + p0[0] = f0; + p0[1] = f1; + p0 += 2; + } + + pp += 2; + } + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = f0; + + if (output_transpose) + { + p0 += out_hstep; + } + else + { + p0++; + } + + pp += 1; + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + gemm_transB_packed_tile_int8_avx512vnni(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + gemm_transB_packed_tile_int8_avxvnniint8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + gemm_transB_packed_tile_int8_avxvnni(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + gemm_transB_packed_tile_int8_avx2(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_xop()) + { + gemm_transB_packed_tile_int8_xop(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + + // actually we only depend the global k==0 condition + (void)i; + (void)j; + + const signed char* pAT = AT_tile; + const signed char* 
pBT = BT_tile; + + int* outptr = topT_tile; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + __m512i _sum8; + __m512i _sum9; + __m512i _suma; + __m512i _sumb; + __m512i _sumc; + __m512i _sumd; + __m512i _sume; + __m512i _sumf; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + _sum8 = _mm512_setzero_si512(); + _sum9 = _mm512_setzero_si512(); + _suma = _mm512_setzero_si512(); + _sumb = _mm512_setzero_si512(); + _sumc = _mm512_setzero_si512(); + _sumd = _mm512_setzero_si512(); + _sume = _mm512_setzero_si512(); + _sumf = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + _sum8 = _mm512_load_si512((const __m512i*)(outptr + 128)); + _sum9 = _mm512_load_si512((const __m512i*)(outptr + 128 + 16)); + _suma = _mm512_load_si512((const __m512i*)(outptr + 128 + 32)); + _sumb = _mm512_load_si512((const __m512i*)(outptr + 128 + 48)); + _sumc = _mm512_load_si512((const __m512i*)(outptr + 128 + 64)); + _sumd = _mm512_load_si512((const __m512i*)(outptr + 128 + 80)); + _sume = _mm512_load_si512((const __m512i*)(outptr + 128 + 96)); + _sumf = _mm512_load_si512((const __m512i*)(outptr + 128 + 112)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA1); + _sum8 = _mm512_dpbusd_epi32(_sum8, _pB0, _pA2); + _sum9 = _mm512_dpbusd_epi32(_sum9, _pB1, _pA2); + _suma = _mm512_dpbusd_epi32(_suma, _pB0, _pA3); + _sumb = _mm512_dpbusd_epi32(_sumb, _pB1, _pA3); + _sumc = _mm512_dpbusd_epi32(_sumc, _pB2, _pA2); + _sumd = _mm512_dpbusd_epi32(_sumd, _pB3, _pA2); + _sume = _mm512_dpbusd_epi32(_sume, _pB2, _pA3); + _sumf = _mm512_dpbusd_epi32(_sumf, _pB3, _pA3); + pA += 
64; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + __m512i _w_shift2 = _mm512_shuffle_i32x4(_w_shift0, _w_shift0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _w_shift3 = _mm512_shuffle_epi32(_w_shift2, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm512_sub_epi32(_sum7, _w_shift1); + _sum8 = _mm512_sub_epi32(_sum8, _w_shift2); + _sum9 = _mm512_sub_epi32(_sum9, _w_shift2); + _suma = _mm512_sub_epi32(_suma, _w_shift3); + _sumb = _mm512_sub_epi32(_sumb, _w_shift3); + _sumc = _mm512_sub_epi32(_sumc, _w_shift2); + _sumd = _mm512_sub_epi32(_sumd, _w_shift2); + _sume = _mm512_sub_epi32(_sume, _w_shift3); + _sumf = _mm512_sub_epi32(_sumf, _w_shift3); + pA += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + // 4567 0123 cdef 89ab + // 6745 2301 efcd ab89 + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + // 89ab cdef 0123 4567 + // 9ab8 defc 1230 5674 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_comp_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_comp_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_comp_dpwssd_epi32(_suma, _pA3, _pB0); + _sumb = _mm512_comp_dpwssd_epi32(_sumb, _pA3, _pB1); + _sumc = _mm512_comp_dpwssd_epi32(_sumc, _pA2, _pB2); + _sumd = _mm512_comp_dpwssd_epi32(_sumd, _pA2, _pB3); + _sume = _mm512_comp_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_comp_dpwssd_epi32(_sumf, _pA3, _pB3); + + pA += 32; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA2 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA3 = _mm256_shuffle_epi32(_pA2, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = 
_mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3))); + _sum8 = _mm512_add_epi32(_sum8, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB0))); + _sum9 = _mm512_add_epi32(_sum9, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB1))); + _suma = _mm512_add_epi32(_suma, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB0))); + _sumb = _mm512_add_epi32(_sumb, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB1))); + _sumc = _mm512_add_epi32(_sumc, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB2))); + _sumd = _mm512_add_epi32(_sumd, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB3))); + _sume = _mm512_add_epi32(_sume, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB2))); + _sumf = _mm512_add_epi32(_sumf, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB3))); + + pA += 16; + pB += 16; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + _mm512_store_si512((__m512i*)(outptr + 128), _sum8); + _mm512_store_si512((__m512i*)(outptr + 128 + 16), _sum9); + _mm512_store_si512((__m512i*)(outptr + 128 + 32), _suma); + _mm512_store_si512((__m512i*)(outptr + 128 + 48), _sumb); + _mm512_store_si512((__m512i*)(outptr + 128 + 64), _sumc); + _mm512_store_si512((__m512i*)(outptr + 128 + 80), _sumd); + _mm512_store_si512((__m512i*)(outptr + 128 + 96), _sume); + _mm512_store_si512((__m512i*)(outptr + 128 + 112), _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < 
max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = combine8x2_epi32(_pB, _pB); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA1); + pA += 64; + pB += 32; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm512_sub_epi32(_sum7, _w_shift1); + pA += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m256i _pBB = _mm256_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 4567 0123 4567 + // 1230 5674 1230 5674 + // 4567 0123 4567 0123 + // 5674 1230 5674 1230 + __m512i _pB0 = combine8x2_epi32(_pBB, _pBB); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + + pA += 32; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); + + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256i _pB0 = combine4x2_epi32(_pB, _pB); + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, 
_mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3))); + + pA += 16; + pB += 8; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + outptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 64; + pB += 16; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + pA += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 0123 0123 0123 + // 1230 1230 1230 1230 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + + pA += 32; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, 
_MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + + pA += 16; + pB += 4; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 64; + pB += 8; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)pA); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift); + pA += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + + // 0101 0101 0101 0101 + // 1010 1010 1010 1010 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + + pA += 32; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(((const short*)pB)[0]); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + + pA += 16; + pB += 2; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + outptr += 32; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB, _pA); + pA += 64; + pB += 4; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)pA); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + pA += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = 
_mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pBBBB = _mm512_cvtepi8_epi16(_pB); + + // 0xxx0xxx0xxx0xxx -> 00000000... + __m512i _pB0 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_AAAA); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + + pA += 32; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_set1_epi16(pB[0]); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB))); + + pA += 16; + pB += 1; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + outptr += 16; + } + + pAT += max_kk * 16; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 64; + } +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA00 = combine8x2_epi32(_pA0, _pA0); + __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA00); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA00); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA11); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA11); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA00); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA00); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA11); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA11); + pA += 32; + pB += 64; + } + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _w_shift00 = combine8x2_epi32(_w_shift0, _w_shift0); + __m512i _w_shift11 = _mm512_shuffle_epi32(_w_shift00, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift00); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift00); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift11); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift11); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift00); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift00); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift11); + _sum7 = _mm512_sub_epi32(_sum7, 
_w_shift11); + pA += 32; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 0123 4567 + // 2301 6745 2301 6745 + __m512i _pA00 = combine8x2_epi32(_pA0, _pA0); + __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + // 4567 0123 cdef 89ab + // 5674 1230 defc 9ab8 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA11, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA11, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA00, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA00, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA11, _pB3); + + pA += 16; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + _pA = _mm_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA00 = combine4x2_epi32(_pA, _pA); + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB3))); + + pA += 8; + pB += 16; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + + outptr += 128; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + __m256i _sum4; + __m256i _sum5; + __m256i _sum6; + __m256i _sum7; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + _sum4 = _mm256_setzero_si256(); + _sum5 = _mm256_setzero_si256(); + _sum6 = 
_mm256_setzero_si256(); + _sum7 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + _sum4 = _mm256_load_si256((const __m256i*)(outptr + 32)); + _sum5 = _mm256_load_si256((const __m256i*)(outptr + 40)); + _sum6 = _mm256_load_si256((const __m256i*)(outptr + 48)); + _sum7 = _mm256_load_si256((const __m256i*)(outptr + 56)); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_dpbssd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_dpbssd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm256_dpbssd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm256_dpbssd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm256_dpbssd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm256_dpbssd_epi32(_sum7, _pB3, _pA1); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm256_comp_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm256_comp_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm256_comp_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm256_comp_dpbusd_epi32(_sum7, _pB3, _pA1); +#endif // __AVXVNNIINT8__ + pA += 32; + pB += 32; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _w_shift1 = _mm256_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm256_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm256_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm256_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm256_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm256_sub_epi32(_sum7, _w_shift1); + pA += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX512F__ + __m128i _pA = _mm_load_si128((const __m128i*)pA); +#else + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); +#endif + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 0123 4567 + // 2301 6745 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + + // 0123 4567 + // 1230 5674 + // 4567 0123 + // 5674 1230 + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = 
_mm256_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm256_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm256_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA0 = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB0 = _mm_loadl_epi64((const __m128i*)pB); + + _pA0 = _mm_cvtepi8_epi16(_pA0); + _pB0 = _mm_cvtepi8_epi16(_pB0); + + __m128i _pA1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pA0, _MM_SHUFFLE(1, 0, 3, 2)), _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB3 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm256_add_epi32(_sum2, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm256_add_epi32(_sum3, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm256_add_epi32(_sum4, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm256_add_epi32(_sum5, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm256_add_epi32(_sum6, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm256_add_epi32(_sum7, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB3))); + + pA += 8; + pB += 8; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); + _mm256_store_si256((__m256i*)(outptr + 32), _sum4); + _mm256_store_si256((__m256i*)(outptr + 40), _sum5); + _mm256_store_si256((__m256i*)(outptr + 48), _sum6); + _mm256_store_si256((__m256i*)(outptr + 56), _sum7); + + outptr += 64; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pB0 = combine4x2_epi32(_pB, _pB); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_dpbssd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_dpbssd_epi32(_sum3, _pB1, _pA1); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB1, _pA1); +#endif // __AVXVNNIINT8__ + pA += 32; + pB += 16; + } +#if 
!__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _w_shift1 = _mm256_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm256_sub_epi32(_sum3, _w_shift1); + pA += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX512F__ + __m128i _pA = _mm_load_si128((const __m128i*)pA); +#else + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); +#endif + __m128i _pB = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 0123 4567 + // 2301 6745 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + + // 0123 0123 + // 1230 1230 + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA0 = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB0 = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + + _pA0 = _mm_cvtepi8_epi16(_pA0); + _pB0 = _mm_cvtepi8_epi16(_pB0); + + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm256_add_epi32(_sum2, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm256_add_epi32(_sum3, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1))); + + pA += 8; + pB += 4; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm256_dpbssd_epi32(_sum1, _pB1, _pA); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA); +#endif // __AVXVNNIINT8__ + pA += 32; + pB += 8; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)pA); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift); + pA += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX512F__ + __m128i _pA = 
_mm_load_si128((const __m128i*)pA); +#else + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); +#endif + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 0123 4567 + + // 0101 0101 + // 1010 1010 + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB0 = _mm_set1_epi16(((const short*)pB)[0]); + + _pA = _mm_cvtepi8_epi16(_pA); + _pB0 = _mm_cvtepi8_epi16(_pB0); + + // 01234567 + + // 01010101 + // 10101010 + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1)); + + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB1))); + + pA += 8; + pB += 2; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + __m256i _sum0; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB, _pA); +#else // __AVXVNNIINT8__ +#if __AVX512VNNI__ && _MSC_VER < 1932 + // old msvc crash here --- nihui + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepu8_epi16(_pB); + __m512i _s0 = _mm512_madd_epi16(_pA0, _pB0); + __m256i _s1 = _mm256_hadd_epi32(_mm512_extracti32x8_epi32(_s0, 0), _mm512_extracti32x8_epi32(_s0, 1)); + _sum0 = _mm256_add_epi32(_sum0, _mm256_permute4x64_epi64(_s1, _MM_SHUFFLE(3, 1, 2, 0))); +#else + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB, _pA); +#endif +#endif // __AVXVNNIINT8__ + pA += 32; + pB += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)pA); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift); + pA += 32; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX512F__ + __m128i _pA = _mm_load_si128((const __m128i*)pA); +#else + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); +#endif + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pBB = _mm256_cvtepi8_epi16(_pB); + + // 0xxx0xxx -> 00000000 11111111 + __m256i _pB0 = _mm256_shuffle_epi32(_pBB, _MM_SHUFFLE(0, 0, 0, 0)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(pB[0]); + + _pA = _mm_cvtepi8_epi16(_pA); + + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB))); + + pA += 8; + pB += 1; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + + outptr += 8; + } + + pAT += max_kk * 8; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + pAT += 32; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && 
!__AVXVNNIINT8__) + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 16; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + pA += 16; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 0123 0123 0123 + // 2301 2301 2301 2301 + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + + pA += 8; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + + pA += 4; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + + outptr += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; 
jj += 8) + { + const signed char* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; +#else // __AVX2__ + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + __m128i _sum4; + __m128i _sum5; + __m128i _sum6; + __m128i _sum7; +#endif // __AVX2__ + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); +#else // __AVX2__ + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + _sum4 = _mm_setzero_si128(); + _sum5 = _mm_setzero_si128(); + _sum6 = _mm_setzero_si128(); + _sum7 = _mm_setzero_si128(); +#endif // __AVX2__ + } + else + { +#if __AVX2__ + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); +#else // __AVX2__ + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + _sum4 = _mm_load_si128((const __m128i*)(outptr + 16)); + _sum5 = _mm_load_si128((const __m128i*)(outptr + 20)); + _sum6 = _mm_load_si128((const __m128i*)(outptr + 24)); + _sum7 = _mm_load_si128((const __m128i*)(outptr + 28)); +#endif // __AVX2__ + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA00 = combine4x2_epi32(_pA0, _pA0); + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB01, _pA00); + _sum1 = _mm256_dpbssd_epi32(_sum1, _pB01, _pA11); + _sum2 = _mm256_dpbssd_epi32(_sum2, _pB23, _pA00); + _sum3 = _mm256_dpbssd_epi32(_sum3, _pB23, _pA11); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB01, _pA00); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB01, _pA11); + _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB23, _pA00); + _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _pB23, _pA11); +#endif // __AVXVNNIINT8__ + pA += 16; + pB += 32; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); + __m256i _w_shift00 = combine4x2_epi32(_w_shift0, _w_shift0); + __m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift00); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift11); + _sum2 = _mm256_sub_epi32(_sum2, _w_shift00); + _sum3 = _mm256_sub_epi32(_sum3, _w_shift11); + pA += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA00 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB01 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA11, _pB01); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA00, _pB23); + 
_sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA11, _pB23); +#else // __AVX2__ +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); +#endif + __m128i _extpB = _mm_cmpgt_epi8(_mm_setzero_si128(), _pB); + __m128i _pBl = _mm_unpacklo_epi8(_pB, _extpB); + __m128i _pBh = _mm_unpackhi_epi8(_pB, _extpB); + + // 0123 + // 2301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(1, 0, 3, 2)); + + // 0123 + // 4567 + // 1230 + // 5674 + __m128i _pB0 = _pBl; + __m128i _pB1 = _pBh; + __m128i _pB2 = _mm_shuffle_epi32(_pBl, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pBh, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm_comp_dpwssd_epi32(_sum7, _pA1, _pB3); +#endif // __AVX2__ + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + +#if __AVX2__ + // 01230123 + // 23012301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 + // 12305674 + __m128i _pB0 = _pB; + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0)); + __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0)); + __m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1)); + __m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1)); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); +#else // __AVX2__ +#if __XOP__ + // 00112233 + // 22330011 + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + + // 00112233 + // 44556677 + // 1.2.3.0. 
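+ // editor's note (added in review): the digit strings above show which source element occupies the two 16-bit halves of every 32-bit lane after the unpack/shuffle, + // and a '.' marks a half whose value is ignored; XOP _mm_maccd_epi16 multiplies only the low 16-bit half of each 32-bit lane of its first two operands + // and accumulates the signed 32-bit products into the third operand.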
+ __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_unpackhi_epi16(_pB, _pB); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maccd_epi16(_pA1, _pB1, _sum3); + _sum4 = _mm_maccd_epi16(_pA0, _pB2, _sum4); + _sum5 = _mm_maccd_epi16(_pA0, _pB3, _sum5); + _sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7); +#else + // 01230123 + // 23012301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 + // 12305674 + __m128i _pB01 = _pB; + __m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _sl2 = _mm_mullo_epi16(_pA0, _pB23); + __m128i _sh2 = _mm_mulhi_epi16(_pA0, _pB23); + __m128i _sl3 = _mm_mullo_epi16(_pA1, _pB23); + __m128i _sh3 = _mm_mulhi_epi16(_pA1, _pB23); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + __m128i _s4 = _mm_unpacklo_epi16(_sl2, _sh2); + __m128i _s5 = _mm_unpackhi_epi16(_sl2, _sh2); + __m128i _s6 = _mm_unpacklo_epi16(_sl3, _sh3); + __m128i _s7 = _mm_unpackhi_epi16(_sl3, _sh3); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); + _sum4 = _mm_add_epi32(_sum4, _s4); + _sum5 = _mm_add_epi32(_sum5, _s5); + _sum6 = _mm_add_epi32(_sum6, _s6); + _sum7 = _mm_add_epi32(_sum7, _s7); +#endif // __XOP__ +#endif // __AVX2__ + + pA += 4; + pB += 8; + } + +#if __AVX2__ + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); +#else // __AVX2__ + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + _mm_store_si128((__m128i*)(outptr + 16), _sum4); + _mm_store_si128((__m128i*)(outptr + 20), _sum5); + _mm_store_si128((__m128i*)(outptr + 24), _sum6); + _mm_store_si128((__m128i*)(outptr + 28), _sum7); +#endif // __AVX2__ + + outptr += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = 
_mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm_dpbssd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm_dpbssd_epi32(_sum3, _pB1, _pA1); +#else // __AVXVNNIINT8__ + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm_comp_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm_comp_dpbusd_epi32(_sum3, _pB1, _pA1); +#endif // __AVXVNNIINT8__ + pA += 16; + pB += 16; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _w_shift1 = _mm_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm_sub_epi32(_sum3, _w_shift1); + pA += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + // 0123 + // 2301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(1, 0, 3, 2)); + + // 0123 + // 1230 + __m128i _pB0 = _pB; + __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + +#if __XOP__ + // 00112233 + // 22330011 + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + + // 00112233 + // 1.2.3.0. 
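+ // editor's note (added in review): when XOP is not available, the #else branch just below emulates a widening 16x16->32 multiply-accumulate with plain SSE2: + // _mm_mullo_epi16/_mm_mulhi_epi16 give the low/high halves of each signed product, and unpacklo/unpackhi interleave them back into 32-bit values + // before the _mm_add_epi32 accumulation.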
+ __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maccd_epi16(_pA1, _pB1, _sum3); +#else + // 0123 0123 + // 2301 2301 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + + // 0123 1230 + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + pA += 4; + pB += 4; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA); +#else // __AVXVNNIINT8__ + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA); +#endif // __AVXVNNIINT8__ + pA += 16; + pB += 8; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)pA); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + _sum1 = _mm_sub_epi32(_sum1, _w_shift); + pA += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB0 = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB0 = _mm_cvtepi8_epi16(_pB0); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB0 = _mm_unpacklo_epi8(_pB0, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB0)); +#endif + + // 0123 + + // 0101 + // 1010 + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_set1_epi16(((const short*)pB)[0]); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + +#if __XOP__ + // 00112233 + _pA = 
_mm_unpacklo_epi16(_pA, _pA); + + // 00110011 + // 1.0.1.0. + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + + _sum0 = _mm_maccd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA, _pB1, _sum1); +#else + // 01230123 + // 01011010 + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 1, 0, 1)); + + __m128i _sl = _mm_mullo_epi16(_pA, _pB01); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + pA += 4; + pB += 2; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + } + + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); +#if __AVXVNNIINT8__ + _sum0 = _mm_dpbssd_epi32(_sum0, _pB, _pA); +#else // __AVXVNNIINT8__ +#if __AVX512VNNI__ && _MSC_VER < 1932 + // old msvc crash here --- nihui + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepu8_epi16(_pB); + __m256i _s0 = _mm256_madd_epi16(_pA0, _pB0); + __m128i _s1 = _mm_hadd_epi32(_mm256_extracti128_si256(_s0, 0), _mm256_extracti128_si256(_s0, 1)); + _sum0 = _mm_add_epi32(_sum0, _s1); +#else + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA); +#endif +#endif // __AVXVNNIINT8__ + pA += 16; + pB += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)pA); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + pA += 16; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(((const short*)pB)[0]); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB); + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(pB[0]); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); +#endif + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + + _sum0 = _mm_maccd_epi16(_pA, _pB, _sum0); +#else + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _sum0 = _mm_add_epi32(_sum0, _s0); +#endif + + pA += 4; + pB += 1; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + + outptr += 4; + } + + pAT += max_kk * 4; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + pAT += 16; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* pB = pBT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0; + __m512i _sum1; + + if (k 
== 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 8; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift); + pA += 8; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0101 0101 0101 0101 + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + + pA += 4; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01010101 01010101 + + // 01234567 89abcdef + // 12305674 9ab8defc + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + + pA += 2; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + + outptr += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); +#else + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); +#endif + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA00 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB01, _pA00); + _sum1 = _mm256_dpbssd_epi32(_sum1, _pB01, _pA11); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB01, _pA00); + _sum1 = 
_mm256_comp_dpbusd_epi32(_sum1, _pB01, _pA11); +#endif // __AVXVNNIINT8__ + pA += 8; + pB += 32; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift00 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift00); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift11); + pA += 8; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA00 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB01 = _mm256_cvtepi8_epi16(_pB); + + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA11, _pB01); +#else // __AVX2__ +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); +#endif + + __m128i _extpB = _mm_cmpgt_epi8(_mm_setzero_si128(), _pB); + __m128i _pB0 = _mm_unpacklo_epi8(_pB, _extpB); + __m128i _pB1 = _mm_unpackhi_epi8(_pB, _extpB); + + // 0101 + // 1010 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + + // 0123 + // 4567 + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); +#endif // __AVX2__ + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + // 01010101 + // 10101010 + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pA, _MM_SHUFFLE(2, 3, 0, 1)), _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 + +#if __AVX2__ + __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB)); + __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB)); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); +#else // __AVX2__ + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif // __AVX2__ + + pA += 2; + pB += 8; + } + +#if __AVX2__ + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); +#else + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); +#endif + + outptr += 16; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = 
_mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); +#if __AVXVNNIINT8__ + _sum0 = _mm_dpbssd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_dpbssd_epi32(_sum1, _pB1, _pA); +#else // __AVXVNNIINT8__ + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _pB1, _pA); +#endif // __AVXVNNIINT8__ + pA += 8; + pB += 16; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + _sum1 = _mm_sub_epi32(_sum1, _w_shift); + pA += 8; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + // 0101 + + // 0123 + // 1230 + __m128i _pB0 = _pB; + __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + // 01010101 + + // 01231230 + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + + __m128i _sl = _mm_mullo_epi16(_pA, _pB01); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + + pA += 2; + pB += 4; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + + outptr += 8; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int sum00; + int sum01; + int sum10; + int sum11; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVXVNNIINT8__ + sum00 += pA[0] * pB[0]; + sum00 += pA[1] * pB[1]; + sum00 += pA[2] * pB[2]; + sum00 += pA[3] * pB[3]; + sum01 += pA[0] * pB[4]; + sum01 += pA[1] * pB[5]; + sum01 += pA[2] * pB[6]; + sum01 += pA[3] * pB[7]; + sum10 += pA[4] * pB[0]; + sum10 += pA[5] * pB[1]; + sum10 += pA[6] * pB[2]; + sum10 += pA[7] * pB[3]; + sum11 += pA[4] * pB[4]; + sum11 += pA[5] * pB[5]; + sum11 += pA[6] * pB[6]; + sum11 += pA[7] * pB[7]; +#else // __AVXVNNIINT8__ + sum00 += pA[0] * ((unsigned 
char*)pB)[0]; + sum00 += pA[1] * ((unsigned char*)pB)[1]; + sum00 += pA[2] * ((unsigned char*)pB)[2]; + sum00 += pA[3] * ((unsigned char*)pB)[3]; + sum01 += pA[0] * ((unsigned char*)pB)[4]; + sum01 += pA[1] * ((unsigned char*)pB)[5]; + sum01 += pA[2] * ((unsigned char*)pB)[6]; + sum01 += pA[3] * ((unsigned char*)pB)[7]; + sum10 += pA[4] * ((unsigned char*)pB)[0]; + sum10 += pA[5] * ((unsigned char*)pB)[1]; + sum10 += pA[6] * ((unsigned char*)pB)[2]; + sum10 += pA[7] * ((unsigned char*)pB)[3]; + sum11 += pA[4] * ((unsigned char*)pB)[4]; + sum11 += pA[5] * ((unsigned char*)pB)[5]; + sum11 += pA[6] * ((unsigned char*)pB)[6]; + sum11 += pA[7] * ((unsigned char*)pB)[7]; +#endif // __AVXVNNIINT8__ + pA += 8; + pB += 8; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + int w_shift0 = ((int*)pA)[0]; + int w_shift1 = ((int*)pA)[1]; + sum00 = sum00 - w_shift0; + sum01 = sum01 - w_shift0; + sum10 = sum10 - w_shift1; + sum11 = sum11 - w_shift1; + pA += 8; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + sum00 += pA[0] * pB[0]; + sum00 += pA[1] * pB[1]; + sum01 += pA[0] * pB[2]; + sum01 += pA[1] * pB[3]; + sum10 += pA[2] * pB[0]; + sum10 += pA[3] * pB[1]; + sum11 += pA[2] * pB[2]; + sum11 += pA[3] * pB[3]; + pA += 4; + pB += 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[0] * pB[1]; + sum10 += pA[1] * pB[0]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVXVNNIINT8__ + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum0 += pA[2] * pB[2]; + sum0 += pA[3] * pB[3]; + sum1 += pA[4] * pB[0]; + sum1 += pA[5] * pB[1]; + sum1 += pA[6] * pB[2]; + sum1 += pA[7] * pB[3]; +#else // __AVXVNNIINT8__ + sum0 += pA[0] * ((unsigned char*)pB)[0]; + sum0 += pA[1] * ((unsigned char*)pB)[1]; + sum0 += pA[2] * ((unsigned char*)pB)[2]; + sum0 += pA[3] * ((unsigned char*)pB)[3]; + sum1 += pA[4] * ((unsigned char*)pB)[0]; + sum1 += pA[5] * ((unsigned char*)pB)[1]; + sum1 += pA[6] * ((unsigned char*)pB)[2]; + sum1 += pA[7] * ((unsigned char*)pB)[3]; +#endif // __AVXVNNIINT8__ + pA += 8; + pB += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + int w_shift0 = ((int*)pA)[0]; + int w_shift1 = ((int*)pA)[1]; + sum0 = sum0 - w_shift0; + sum1 = sum1 - w_shift1; + pA += 8; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } +#endif // __SSE2__ + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + + pAT += max_kk * 2; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + pAT += 8; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } + for (; ii < max_ii; ii += 1) + { + const signed char* pB = pBT; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0; + + if (k 
== 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB = _mm512_loadu_si512((const __m512i*)pB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB, _pA); + pA += 4; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + pA += 4; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_set1_epi16(((const short*)pA)[0]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + + pA += 2; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m256i _pA = _mm256_set1_epi16(pA[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA, _pB0))); + + pA += 1; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + + outptr += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { +#if __AVX2__ + __m256i _sum0; +#else + __m128i _sum0; + __m128i _sum1; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); +#else + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); +#endif + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA00 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB); +#if __AVXVNNIINT8__ + _sum0 = _mm256_dpbssd_epi32(_sum0, _pB01, _pA00); +#else // __AVXVNNIINT8__ + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB01, _pA00); +#endif // __AVXVNNIINT8__ + pA += 4; + pB += 32; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m256i _w_shift = _mm256_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift); + pA += 4; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); +#else +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); +#endif + + __m128i _extpB = _mm_cmpgt_epi8(_mm_setzero_si128(), _pB); + __m128i _pB0 = _mm_unpacklo_epi8(_pB, _extpB); + __m128i _pB1 = _mm_unpackhi_epi8(_pB, _extpB); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); +#endif // __AVX2__ + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + +#if __AVX2__ + __m256i _s0 = 
_mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB)); + + _sum0 = _mm256_add_epi32(_sum0, _s0); +#else + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif // __AVX2__ + + pA += 1; + pB += 8; + } + +#if __AVX2__ + _mm256_storeu_si256((__m256i*)outptr, _sum0); +#else + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); +#endif + + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); +#if __AVXVNNIINT8__ + _sum0 = _mm_dpbssd_epi32(_sum0, _pB, _pA); +#else // __AVXVNNIINT8__ + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _pB, _pA); +#endif // __AVXVNNIINT8__ + pA += 4; + pB += 16; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + __m128i _w_shift = _mm_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + pA += 4; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pA = _mm_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pA = _mm_unpacklo_epi8(_pA, _mm_cmpgt_epi8(_mm_setzero_si128(), _pA)); + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + // 0xxx -> 0000 + __m128i _pA0 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(0, 0, 0, 0)); + + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + +#if __SSE4_1__ + _pB = _mm_cvtepi8_epi16(_pB); +#else + _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); +#endif + + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _sum0 = _mm_add_epi32(_sum0, _s0); + + pA += 1; + pB += 4; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __SSE2__ +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVXVNNIINT8__ + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum0 += pA[2] * pB[2]; + sum0 += pA[3] * pB[3]; + sum1 += pA[0] * pB[4]; + sum1 += pA[1] * pB[5]; + sum1 += pA[2] * pB[6]; + sum1 += pA[3] * pB[7]; +#else // __AVXVNNIINT8__ + sum0 += pA[0] * ((unsigned char*)pB)[0]; + sum0 += pA[1] * ((unsigned char*)pB)[1]; + sum0 += pA[2] * ((unsigned char*)pB)[2]; + sum0 += pA[3] * ((unsigned char*)pB)[3]; + sum1 += pA[0] * ((unsigned char*)pB)[4]; + sum1 += pA[1] * ((unsigned char*)pB)[5]; + sum1 += pA[2] * ((unsigned char*)pB)[6]; + sum1 += pA[3] * ((unsigned char*)pB)[7]; +#endif // __AVXVNNIINT8__ + pA += 4; + pB += 8; 
+ } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + int w_shift = ((const int*)pA)[0]; + sum0 = sum0 - w_shift; + sum1 = sum1 - w_shift; + pA += 4; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } +#endif // __SSE2__ + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + int sum; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ + for (; kk + 3 < max_kk; kk += 4) + { +#if __AVXVNNIINT8__ + sum += pA[0] * pB[0]; + sum += pA[1] * pB[1]; + sum += pA[2] * pB[2]; + sum += pA[3] * pB[3]; +#else // __AVXVNNIINT8__ + sum += pA[0] * ((unsigned char*)pB)[0]; + sum += pA[1] * ((unsigned char*)pB)[1]; + sum += pA[2] * ((unsigned char*)pB)[2]; + sum += pA[3] * ((unsigned char*)pB)[3]; +#endif // __AVXVNNIINT8__ + pA += 4; + pB += 4; + } +#if !__AVXVNNIINT8__ + if (max_kk >= 4) + { + int w_shift = ((const int*)pA)[0]; + sum = sum - w_shift; + pA += 4; + } +#endif // !__AVXVNNIINT8__ +#endif // __AVX512VNNI__ || __AVXVNNI__ + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + + outptr += 1; + } + + pAT += max_kk; +#if __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + if (max_kk >= 4) + { + pAT += 4; + } +#endif // __AVX512VNNI__ || (__AVXVNNI__ && !__AVXVNNIINT8__) + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / (2 * sizeof(signed char) + sizeof(int))); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); + TILE_N = std::max(16, tile_size / 16 * 16); + TILE_K = std::max(16, tile_size / 16 * 16); +#elif __AVX__ + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(8, tile_size / 8 * 8); + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); + TILE_N = std::max(4, tile_size / 4 * 4); + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 2); + TILE_N = std::max(1, tile_size); + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(signed char) / TILE_K); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); + TILE_N = std::max(16, tile_size / 16 * 16); +#elif __AVX__ + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 2); + 
TILE_N = std::max(1, tile_size); +#endif + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __AVX512F__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 15) / 16 * 16); +#elif __AVX__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 7) / 8 * 8); +#elif __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + } + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { +#if __AVX512F__ + TILE_M = (constant_TILE_M + 15) / 16 * 16; +#elif __AVX__ + TILE_M = (constant_TILE_M + 7) / 8 * 8; +#elif __SSE2__ + TILE_M = (constant_TILE_M + 3) / 4 * 4; +#else + TILE_M = (constant_TILE_M + 1) / 2 * 2; +#endif + } + + if (constant_TILE_N > 0) + { +#if __AVX512F__ + TILE_N = (constant_TILE_N + 15) / 16 * 16; +#elif __AVX__ + TILE_N = (constant_TILE_N + 7) / 8 * 8; +#elif __SSE2__ + TILE_N = (constant_TILE_N + 3) / 4 * 4; +#else + TILE_N = constant_TILE_N; +#endif + } + + if (constant_TILE_K > 0) + { +#if __AVX512F__ + TILE_K = (constant_TILE_K + 15) / 16 * 16; +#elif __AVX__ + TILE_K = (constant_TILE_K + 7) / 8 * 8; +#elif __SSE2__ + TILE_K = (constant_TILE_K + 3) / 4 * 4; +#else + TILE_K = (constant_TILE_K + 1) / 2 * 2; +#endif + } +} diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index 268f85f332d8..9d89f034b5a0 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -16,8 +16,13 @@ #if __SSE2__ #include +#include "sse_mathfun.h" #if __AVX__ #include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ #include "x86_usability.h" @@ -26,6 +31,10 @@ namespace ncnn { +#if NCNN_INT8 +#include "gemm_int8.h" +#endif + Gemm_x86::Gemm_x86() { #if __SSE2__ @@ -7225,8 +7234,7 @@ int Gemm_x86::create_pipeline(const Option& opt) #if NCNN_INT8 if (int8_scale_term) { - support_packing = false; - return 0; + return create_pipeline_int8(opt); } #endif @@ -7366,7 +7374,8 @@ int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& to #if NCNN_INT8 if (int8_scale_term) { - return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + // return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + return forward_int8(bottom_blobs, top_blobs, opt); } #endif @@ -7557,4 +7566,808 @@ int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } +#if NCNN_INT8 +static void compute_A_tile_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + compute_A_tile_fp32_int8_scales(A, scales, B_scale, out_descales, i, max_ii); +} + +static void transpose_compute_A_tile_int8_scales(const Mat& A, Mat& scales, 
float B_scale, Mat& out_descales, int i, int max_ii) +{ + transpose_compute_A_tile_fp32_int8_scales(A, scales, B_scale, out_descales, i, max_ii); +} + +static void pack_A_tile_quantize(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +static void transpose_pack_A_tile_quantize(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +static void compute_B_int8_scale(const Mat& B, float& scale) +{ + compute_B_fp32_int8_scale(B, scale); +} + +static void pack_B_tile_quantize(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +static void transpose_pack_B_tile_quantize(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +static void unpack_output_tile_dequantize(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) +{ + unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose); +} + +struct gemm_x86_int8_omp_args +{ + int TILE_M; + int TILE_N; + int TILE_K; + int broadcast_type_C; + int transA; + int output_transpose; + float alpha; + float beta; +}; + +static int gemm_x86_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("gemm_x86_int8"); + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat ATX; +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + bool has_w_shift = false; + if (TILE_K >= 4) + { + has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni(); +#if NCNN_AVXVNNIINT8 + if (ncnn::cpu_support_x86_avx_vnni_int8()) + has_w_shift = false; +#endif // NCNN_AVXVNNIINT8 + } + if (has_w_shift) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 
2 : 1; + ATX.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + ATX.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + if (ATX.empty()) + return -100; + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 1u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + + Mat A_int8_scales(M, 4u, opt.workspace_allocator); + if (A_int8_scales.empty()) + return -100; + + // dynamic quantize B + float B_int8_scale; + compute_B_int8_scale(B, B_int8_scale); + + // const float output_descale = 1.f / (A_int8_scale * B_int8_scale); + Mat output_descales(M, 4u, opt.workspace_allocator); + if (output_descales.empty()) + return -100; + + // NCNN_LOGE("arm ds %f %f", 1/A_int8_scale, 1/B_int8_scale); + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + pack_B_tile_quantize(B, BT_tile, j, max_jj, k, max_kk, B_int8_scale); + else + transpose_pack_B_tile_quantize(B, BT_tile, j, max_jj, k, max_kk, B_int8_scale); + } + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + const struct gemm_x86_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta}; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + // shadowed variable for less openmp task args + const int TILE_M = args.TILE_M; + const int TILE_N = args.TILE_N; + const int TILE_K = args.TILE_K; + const int broadcast_type_C = args.broadcast_type_C; + const int transA = args.transA; + const int output_transpose = args.output_transpose; + const float alpha = args.alpha; + const float beta = args.beta; + // const int input_elemtype = args.input_elemtype; + // const int output_elemtype = args.output_elemtype; + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? 
A.c : A.h) * A.elempack : A.w; + + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (k == 0) + { + if (transA) + transpose_compute_A_tile_int8_scales(A, A_int8_scales, B_int8_scale, output_descales, i, max_ii); + else + compute_A_tile_int8_scales(A, A_int8_scales, B_int8_scale, output_descales, i, max_ii); + + // NCNN_LOGE("A_int8_scales %f B_int8_scale %f", A_int8_scales[0], B_int8_scale); + } + + if (transA) + transpose_pack_A_tile_quantize(A, AT_tile, i, max_ii, k, max_kk, A_int8_scales); + else + pack_A_tile_quantize(A, AT_tile, i, max_ii, k, max_kk, A_int8_scales); + } + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + } + + unpack_output_tile_dequantize(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta, output_transpose); + } + } + + return 0; +} + +static int gemm_AT_x86_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("gemm_AT_x86_int8"); + + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 1u, opt.workspace_allocator); + if (BT.empty()) + return -100; + + const int nn_NK = nn_N * nn_K; + + // dynamic quantize B + float B_int8_scale; + compute_B_int8_scale(B, B_int8_scale); + + // NCNN_LOGE("%.4f %.4f", A_int8_scale, B_int8_scale); + + // const float output_descale = 1.f / (A_int8_scale * B_int8_scale); + Mat output_descales(M, 4u, opt.workspace_allocator); + if (output_descales.empty()) + return -100; + + for (int i = 0; i < M; i++) + { + output_descales[i] = 1.f / (A_int8_scales[i] * B_int8_scale); + } + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + pack_B_tile_quantize(B, BT_tile, j, max_jj, k, max_kk, B_int8_scale); + else + transpose_pack_B_tile_quantize(B, BT_tile, j, max_jj, k, max_kk, B_int8_scale); + } + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + const struct gemm_x86_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, 
alpha, beta}; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + // shadowed variable for less openmp task args + const int TILE_M = args.TILE_M; + const int TILE_N = args.TILE_N; + const int TILE_K = args.TILE_K; + const int broadcast_type_C = args.broadcast_type_C; + const int output_transpose = args.output_transpose; + const float alpha = args.alpha; + const float beta = args.beta; + // const int output_elemtype = args.output_elemtype; + + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + } + + unpack_output_tile_dequantize(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta, output_transpose); + } + } + + return 0; +} + +static int gemm_BT_x86_int8(const Mat& A, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("gemm_BT_x86_int8"); + + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat A_int8_scales(M, 4u, opt.workspace_allocator); + if (A_int8_scales.empty()) + return -100; + + // const float output_descale = 1.f / (A_int8_scale * B_int8_scale); + Mat output_descales(M, 4u, opt.workspace_allocator); + if (output_descales.empty()) + return -100; + + // NCNN_LOGE("scale %.4f %.4f", A_int8_scale, B_int8_scale); + + Mat ATX; +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + bool has_w_shift = false; + if (TILE_K >= 4) + { + has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni(); +#if NCNN_AVXVNNIINT8 + if (ncnn::cpu_support_x86_avx_vnni_int8()) + has_w_shift = false; +#endif // NCNN_AVXVNNIINT8 + } + if (has_w_shift) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 
2 : 1; + ATX.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + ATX.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + if (ATX.empty()) + return -100; + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + const struct gemm_x86_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta}; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + // shadowed variable for less openmp task args + const int TILE_M = args.TILE_M; + const int TILE_N = args.TILE_N; + const int TILE_K = args.TILE_K; + const int broadcast_type_C = args.broadcast_type_C; + const int transA = args.transA; + const int output_transpose = args.output_transpose; + const float alpha = args.alpha; + const float beta = args.beta; + // const int input_elemtype = args.input_elemtype; + // const int output_elemtype = args.output_elemtype; + + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (k == 0) + { + if (transA) + transpose_compute_A_tile_int8_scales(A, A_int8_scales, B_int8_scale, output_descales, i, max_ii); + else + compute_A_tile_int8_scales(A, A_int8_scales, B_int8_scale, output_descales, i, max_ii); + + // NCNN_LOGE("A_int8_scales %f B_int8_scale %f", A_int8_scales[0], B_int8_scale); + } + + if (transA) + transpose_pack_A_tile_quantize(A, AT_tile, i, max_ii, k, max_kk, A_int8_scales); + else + pack_A_tile_quantize(A, AT_tile, i, max_ii, k, max_kk, A_int8_scales); + } + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + } + + unpack_output_tile_dequantize(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta, output_transpose); + } + } + + return 0; +} + +static int gemm_AT_BT_x86_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& BT, float B_int8_scale, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("gemm_AT_BT_x86_int8"); + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + // const float output_descale = 1.f / (A_int8_scale * B_int8_scale); + Mat output_descales(M, 4u, opt.workspace_allocator); + if (output_descales.empty()) + return -100; + + for (int i = 0; i < M; i++) + { + 
output_descales[i] = 1.f / (A_int8_scales[i] * B_int8_scale); + } + + Mat topT(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + if (topT.empty()) + return -100; + + const struct gemm_x86_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta}; + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + // shadowed variable for less openmp task args + const int TILE_M = args.TILE_M; + const int TILE_N = args.TILE_N; + const int TILE_K = args.TILE_K; + const int broadcast_type_C = args.broadcast_type_C; + const int output_transpose = args.output_transpose; + const float alpha = args.alpha; + const float beta = args.beta; + // const int output_elemtype = args.output_elemtype; + + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + } + + unpack_output_tile_dequantize(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta, output_transpose); + } + } + + return 0; +} + +int Gemm_x86::create_pipeline_int8(const Option& opt) +{ + if (constantA) + { + const int M = constantM; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + bool has_w_shift = false; + if (TILE_K >= 4) + { + has_w_shift = ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni(); +#if NCNN_AVXVNNIINT8 + if (ncnn::cpu_support_x86_avx_vnni_int8()) + has_w_shift = false; +#endif // NCNN_AVXVNNIINT8 + } + if (has_w_shift) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 
2 : 1; + AT_data.create((TILE_K + w_shift_count * 4) * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0); + } + if (AT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_int8(A_data, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_int8(A_data, AT_tile, i, max_ii, k, max_kk); + } + } + } + + if (opt.lightmode) + A_data.release(); + } + + if (constantB) + { + const int N = constantN; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_N = (N + TILE_N - 1) / TILE_N; + + BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 1u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_N; ppj++) + { + const int j = ppj * TILE_N; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_int8(B_data, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_int8(B_data, BT_tile, j, max_jj, k, max_kk); + } + } + } + + if (opt.lightmode) + B_data.release(); + } + + if (constantC && constant_broadcast_type_C != -1) + { + CT_data = C_data; + + if (opt.lightmode) + C_data.release(); + } + + if (constantA || constantB || constantC) + { + nT = opt.num_threads; + } + + return 0; +} + +int Gemm_x86::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? 
bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + } + } + + int out_elempack = 1; +#if __SSE2__ + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; +#if __AVX512F__ + out_elempack = outh % 16 == 0 ? 16 : outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; +#elif __AVX__ + out_elempack = outh % 8 == 0 ? 8 : outh % 4 == 0 ? 4 : 1; +#else + out_elempack = outh % 4 == 0 ? 4 : 1; +#endif + } +#endif // __SSE2__ + + // FIXME use output_elempack + // int output_elempack = out_elempack > 4 ? 4 : out_elempack; + + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 4u * out_elempack; + + // FIXME use output_elemtype instead of input_elemtype + // int output_elemtype = input_elemtype; + + // TODO use output_elemtype + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_x86_int8(AT_data, A_data_int8_scales, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_x86_int8(AT_data, A_data_int8_scales, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_x86_int8(A, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_x86_int8(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, beta, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + + return ret; +} +#endif + } // namespace ncnn diff --git a/src/layer/x86/gemm_x86.h b/src/layer/x86/gemm_x86.h index 6f8eb4a82bfc..b7833ae6d598 100644 --- a/src/layer/x86/gemm_x86.h +++ b/src/layer/x86/gemm_x86.h @@ -28,6 +28,12 @@ class Gemm_x86 : public Gemm virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif + public: int nT; Mat AT_data; diff --git a/src/layer/x86/gemm_x86_avx2.cpp b/src/layer/x86/gemm_x86_avx2.cpp new file mode 100644 index 000000000000..ccc161240c6d --- /dev/null +++ b/src/layer/x86/gemm_x86_avx2.cpp @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
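+// This source file is presumably built with AVX2 code generation enabled, so the shared kernels included from gemm_int8.h below are instantiated here under *_avx2 names for the runtime-dispatch path in Gemm_x86.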
+ +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avx2(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_fp32_avx2(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) +{ + unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose); +} + +void gemm_transB_packed_tile_int8_avx2(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/gemm_x86_avx512vnni.cpp b/src/layer/x86/gemm_x86_avx512vnni.cpp new file mode 100644 index 000000000000..fd72dd66d205 --- /dev/null +++ b/src/layer/x86/gemm_x86_avx512vnni.cpp @@ -0,0 +1,79 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
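+// Presumably compiled with AVX512-VNNI enabled: the gemm_int8.h kernels included below can then take the _mm512_dpbusd_epi32 paths, and are exported under *_avx512vnni names for runtime dispatch.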
+ +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_avx512vnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/gemm_x86_avxvnni.cpp b/src/layer/x86/gemm_x86_avxvnni.cpp new file mode 100644 index 000000000000..b738df821793 --- /dev/null +++ b/src/layer/x86/gemm_x86_avxvnni.cpp @@ -0,0 +1,76 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
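+// Presumably compiled with AVX-VNNI enabled so the included gemm_int8.h kernels use the 256-bit dpbusd paths; the wrappers are exported under *_avxvnni names for runtime dispatch.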
+ +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/gemm_x86_avxvnniint8.cpp b/src/layer/x86/gemm_x86_avxvnniint8.cpp new file mode 100644 index 000000000000..ca1d485ef318 --- /dev/null +++ b/src/layer/x86/gemm_x86_avxvnniint8.cpp @@ -0,0 +1,76 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
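+// Presumably compiled with AVX-VNNI-INT8 enabled: the __AVXVNNIINT8__ paths in gemm_int8.h use the signed-by-signed _mm256_dpbssd_epi32 / _mm_dpbssd_epi32 instructions directly, so the w_shift compensation needed by the dpbusd paths is skipped; wrappers are exported under *_avxvnniint8 names.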
+ +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avxvnniint8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avxvnniint8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_avxvnniint8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/gemm_x86_xop.cpp b/src/layer/x86/gemm_x86_xop.cpp new file mode 100644 index 000000000000..1673d6317d01 --- /dev/null +++ b/src/layer/x86/gemm_x86_xop.cpp @@ -0,0 +1,36 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void gemm_transB_packed_tile_int8_xop(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h index 0bc9cda343a5..a2ee972ea01b 100644 --- a/src/layer/x86/lstm_int8.h +++ b/src/layer/x86/lstm_int8.h @@ -18,7 +18,7 @@ void lstm_dynamic_quantize_scale2int8_avx512vnni(const float* ptr, int size, flo void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt); void lstm_dynamic_quantize_scale2int8_avxvnni(const float* ptr, int size, float scale, signed char* outptr); void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); @@ -43,7 +43,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { lstm_transform_weight_int8_avxvnni(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); @@ -870,10 +870,10 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); - _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); - _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + _w0_shift = _mm256_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_comp_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_comp_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_comp_dpbusd_epi32(_w3_shift, _v127, _w3); kptr += 128; } @@ -900,8 +900,8 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = 
_mm256_loadu_si256((const __m256i*)(kptr + 32)); - _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w0_shift = _mm256_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_comp_dpbusd_epi32(_w1_shift, _v127, _w1); kptr += 64; } @@ -946,7 +946,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[24 + 7] = weight_xc_G_1[i + 3]; __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); - _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _w); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _w); kptr += 32; } @@ -1062,10 +1062,10 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); - _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); - _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + _w0_shift = _mm256_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_comp_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_comp_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_comp_dpbusd_epi32(_w3_shift, _v127, _w3); kptr += 128; } @@ -1092,8 +1092,8 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w0_shift = _mm256_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_comp_dpbusd_epi32(_w1_shift, _v127, _w1); kptr += 64; } @@ -1138,7 +1138,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[24 + 7] = weight_hc_G_1[i + 3]; __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); - _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _w); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _w); kptr += 32; } @@ -1299,10 +1299,10 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); - _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); - _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + _w0_shift = _mm_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_comp_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_comp_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_comp_dpbusd_epi32(_w3_shift, _v127, _w3); kptr += 64; } @@ -1326,8 +1326,8 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w0_shift = _mm_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_comp_dpbusd_epi32(_w1_shift, _v127, _w1); kptr += 32; } @@ -1356,7 +1356,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[8 + 
7] = weight_xc_G[i + 3]; __m128i _w = _mm_loadu_si128((const __m128i*)kptr); - _w_shift = _mm_dpbusd_epi32(_w_shift, _v127, _w); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _w); kptr += 16; } @@ -1437,10 +1437,10 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); - _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); - _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + _w0_shift = _mm_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_comp_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_comp_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_comp_dpbusd_epi32(_w3_shift, _v127, _w3); kptr += 64; } @@ -1464,8 +1464,8 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); - _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w0_shift = _mm_comp_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_comp_dpbusd_epi32(_w1_shift, _v127, _w1); kptr += 32; } @@ -1494,7 +1494,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr[8 + 7] = weight_hc_G[i + 3]; __m128i _w = _mm_loadu_si128((const __m128i*)kptr); - _w_shift = _mm_dpbusd_epi32(_w_shift, _v127, _w); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _w); kptr += 16; } @@ -1621,7 +1621,7 @@ static void lstm_dynamic_quantize_scale2int8(const float* ptr, int size, float s } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { lstm_dynamic_quantize_scale2int8_avxvnni(ptr, size, scale, outptr); @@ -1705,7 +1705,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { lstm_int8_avxvnni(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); @@ -2004,11 +2004,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA); -#if __AVX512VNNI__ - _lstm_IFOGx0 = _mm512_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); -#else - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _mm512_madd_epi16(_ww, _xixi0)); -#endif // __AVX512VNNI__ + _lstm_IFOGx0 = _mm512_comp_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); kptr += 32; } @@ -2191,11 +2187,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA); -#if __AVX512VNNI__ - _lstm_IFOGh0 = _mm512_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); -#else - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _mm512_madd_epi16(_ww, _hh_cont0)); -#endif // __AVX512VNNI__ + _lstm_IFOGh0 = 
_mm512_comp_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); kptr += 32; } @@ -2273,10 +2265,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _xii = _mm256_inserti128_si256(_mm256_castsi128_si256(_xi), _xi, 1); - _sum0 = _mm256_dpbusd_epi32(_sum0, _xii, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _xii, _w1); - _sum2 = _mm256_dpbusd_epi32(_sum2, _xii, _w2); - _sum3 = _mm256_dpbusd_epi32(_sum3, _xii, _w3); + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _xii, _w1); + _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _xii, _w2); + _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _xii, _w3); kptr += 128; } @@ -2296,8 +2288,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - _sum0 = _mm256_dpbusd_epi32(_sum0, _xi, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _xi, _w1); + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _xi, _w1); kptr += 64; } @@ -2314,7 +2306,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #ifdef _MSC_VER _xi = _mm256_add_epi32(_xi, _mm256_set1_epi8(127)); #endif - _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); + _lstm_IFOGx0 = _mm256_comp_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); kptr += 32; } @@ -2394,11 +2386,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _xixi0 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(0, 0, 0, 0)); -#if __AVXVNNI__ || __AVX512VNNI__ - _lstm_IFOGx0 = _mm256_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); -#else - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _mm256_madd_epi16(_ww, _xixi0)); -#endif // __AVXVNNI__ || __AVX512VNNI__ + _lstm_IFOGx0 = _mm256_comp_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); kptr += 16; } @@ -2434,10 +2422,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _hh_cont = _mm256_broadcastsi128_si256(_h_cont); - _sum0 = _mm256_dpbusd_epi32(_sum0, _hh_cont, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _hh_cont, _w1); - _sum2 = _mm256_dpbusd_epi32(_sum2, _hh_cont, _w2); - _sum3 = _mm256_dpbusd_epi32(_sum3, _hh_cont, _w3); + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _hh_cont, _w1); + _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _hh_cont, _w2); + _sum3 = _mm256_comp_dpbusd_epi32(_sum3, _hh_cont, _w3); kptr += 128; } @@ -2457,8 +2445,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - _sum0 = _mm256_dpbusd_epi32(_sum0, _h_cont, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _h_cont, _w1); + _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _h_cont, _w1); kptr += 64; } @@ -2475,7 +2463,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #ifdef _MSC_VER _h_cont = _mm256_add_epi32(_h_cont, _mm256_set1_epi8(127)); #endif - _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); + _lstm_IFOGh0 = _mm256_comp_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); kptr += 32; } @@ -2555,11 +2543,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _hh_cont0 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(0, 0, 0, 0)); -#if __AVXVNNI__ || __AVX512VNNI__ - 
_lstm_IFOGh0 = _mm256_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); -#else - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _mm256_madd_epi16(_ww, _hh_cont0)); -#endif // __AVXVNNI__ || __AVX512VNNI__ + _lstm_IFOGh0 = _mm256_comp_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); kptr += 16; } @@ -2635,10 +2619,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); - _sum2 = _mm_dpbusd_epi32(_sum2, _xi, _w2); - _sum3 = _mm_dpbusd_epi32(_sum3, _xi, _w3); + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _xi, _w1); + _sum2 = _mm_comp_dpbusd_epi32(_sum2, _xi, _w2); + _sum3 = _mm_comp_dpbusd_epi32(_sum3, _xi, _w3); kptr += 64; } @@ -2659,8 +2643,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _xi, _w1); kptr += 32; } @@ -2677,7 +2661,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #ifdef _MSC_VER _xi = _mm_add_epi32(_xi, _mm_set1_epi8(127)); #endif - _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); + _lstm_IFOGx0 = _mm_comp_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); kptr += 16; } @@ -2712,21 +2696,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); - _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); - _sum2 = _mm_maddd_epi16(_w2, _xi, _sum2); - _sum3 = _mm_maddd_epi16(_w3, _xi, _sum3); -#else - __m128i _s0 = _mm_madd_epi16(_w0, _xi); - __m128i _s1 = _mm_madd_epi16(_w1, _xi); - __m128i _s2 = _mm_madd_epi16(_w2, _xi); - __m128i _s3 = _mm_madd_epi16(_w3, _xi); - _sum0 = _mm_add_epi32(_sum0, _s0); - _sum1 = _mm_add_epi32(_sum1, _s1); - _sum2 = _mm_add_epi32(_sum2, _s2); - _sum3 = _mm_add_epi32(_sum3, _s3); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _w0, _xi); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _w1, _xi); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _w2, _xi); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _w3, _xi); kptr += 32; } @@ -2757,15 +2730,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); - _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); -#else - __m128i _s0 = _mm_madd_epi16(_w0, _xi); - __m128i _s1 = _mm_madd_epi16(_w1, _xi); - _sum0 = _mm_add_epi32(_sum0, _s0); - _sum1 = _mm_add_epi32(_sum1, _s1); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _w0, _xi); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _w1, _xi); kptr += 16; } @@ -2794,11 +2760,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); #endif -#if __XOP__ - _lstm_IFOGx0 = _mm_maddd_epi16(_w, _xi, _lstm_IFOGx0); -#else - _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _mm_madd_epi16(_w, _xi)); -#endif + _lstm_IFOGx0 = _mm_comp_dpwssd_epi32(_lstm_IFOGx0, _w, _xi); kptr += 8; } @@ -2844,10 
+2806,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); - _sum2 = _mm_dpbusd_epi32(_sum2, _h_cont, _w2); - _sum3 = _mm_dpbusd_epi32(_sum3, _h_cont, _w3); + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _h_cont, _w1); + _sum2 = _mm_comp_dpbusd_epi32(_sum2, _h_cont, _w2); + _sum3 = _mm_comp_dpbusd_epi32(_sum3, _h_cont, _w3); kptr += 64; } @@ -2868,8 +2830,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); - _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); + _sum0 = _mm_comp_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_comp_dpbusd_epi32(_sum1, _h_cont, _w1); kptr += 32; } @@ -2886,7 +2848,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #ifdef _MSC_VER _h_cont = _mm_add_epi32(_h_cont, _mm_set1_epi8(127)); #endif - _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); + _lstm_IFOGh0 = _mm_comp_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); kptr += 16; } @@ -2921,21 +2883,10 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); - _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); - _sum2 = _mm_maddd_epi16(_w2, _h_cont, _sum2); - _sum3 = _mm_maddd_epi16(_w3, _h_cont, _sum3); -#else - __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); - __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); - __m128i _s2 = _mm_madd_epi16(_w2, _h_cont); - __m128i _s3 = _mm_madd_epi16(_w3, _h_cont); - _sum0 = _mm_add_epi32(_sum0, _s0); - _sum1 = _mm_add_epi32(_sum1, _s1); - _sum2 = _mm_add_epi32(_sum2, _s2); - _sum3 = _mm_add_epi32(_sum3, _s3); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _w0, _h_cont); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _w1, _h_cont); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _w2, _h_cont); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _w3, _h_cont); kptr += 32; } @@ -2966,15 +2917,8 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); - _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); -#else - __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); - __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); - _sum0 = _mm_add_epi32(_sum0, _s0); - _sum1 = _mm_add_epi32(_sum1, _s1); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _w0, _h_cont); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _w1, _h_cont); kptr += 16; } @@ -3003,11 +2947,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); #endif -#if __XOP__ - _lstm_IFOGh0 = _mm_maddd_epi16(_w, _h_cont, _lstm_IFOGh0); -#else - _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _mm_madd_epi16(_w, _h_cont)); -#endif + _lstm_IFOGh0 = _mm_comp_dpwssd_epi32(_lstm_IFOGh0, _w, _h_cont); kptr += 8; } diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index e3d1d11fbc53..f25b06745e84 100644 --- a/src/layer/x86/x86_usability.h +++ 
b/src/layer/x86/x86_usability.h @@ -122,6 +122,28 @@ static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m1 _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); } +static NCNN_FORCEINLINE void transpose8x4_epi8(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpacklo_epi8(_r2, _r3); + + _r0 = _mm_unpacklo_epi16(_tmp0, _tmp1); + _r1 = _mm_unpackhi_epi16(_tmp0, _tmp1); +} + +static NCNN_FORCEINLINE void transpose16x4_epi8(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi8(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi8(_r2, _r3); + + _r0 = _mm_unpacklo_epi16(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi16(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi16(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi16(_tmp1, _tmp3); +} + static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -267,83 +289,96 @@ static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128& return _v; } -#ifndef __FMA__ -static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) -{ - return _mm_add_ps(_mm_mul_ps(_a, _b), _c); -} -static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) -{ - return _mm_sub_ps(_c, _mm_mul_ps(_a, _b)); -} -static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) -{ - return _mm_sub_ps(_mm_mul_ps(_a, _b), _c); -} -static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) -{ - return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1))); -} -#else static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { +#if __FMA__ return _mm_fmadd_ps(_a, _b, _c); +#else + return _mm_add_ps(_mm_mul_ps(_a, _b), _c); +#endif } + static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { // return -a * b + c +#if __FMA__ return _mm_fnmadd_ps(_a, _b, _c); +#else + return _mm_sub_ps(_c, _mm_mul_ps(_a, _b)); +#endif } + static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) { +#if __FMA__ return _mm_fmsub_ps(_a, _b, _c); +#else + return _mm_sub_ps(_mm_mul_ps(_a, _b), _c); +#endif } + static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) { +#if __FMA__ return _mm_fnmsub_ps(_a, _b, _c); +#else + return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1))); +#endif } -#endif // !__FMA__ -#if __AVX__ -#ifndef __FMA__ -static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) -{ - return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c); -} -static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) -{ - return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b)); -} -static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) -{ - return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c); -} -static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +static NCNN_FORCEINLINE __m128i _mm_comp_dpwssd_epi32(__m128i src, __m128i a, __m128i b) { - return _mm256_sub_ps(_c, 
_mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1))); -} +#if __AVX512VNNI__ + return _mm_dpwssd_epi32(src, a, b); +#elif __AVXVNNI__ + return _mm_dpwssd_avx_epi32(src, a, b); +#elif __XOP__ + return _mm_maddd_epi16(a, b, src); #else + return _mm_add_epi32(src, _mm_madd_epi16(a, b)); +#endif +} + +#if __AVX__ static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) { // return a * b + c +#if __FMA__ return _mm256_fmadd_ps(_a, _b, _c); +#else + return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c); +#endif } + static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) { // return -a * b + c +#if __FMA__ return _mm256_fnmadd_ps(_a, _b, _c); +#else + return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b)); +#endif } + static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) { // return a * b - c +#if __FMA__ return _mm256_fmsub_ps(_a, _b, _c); +#else + return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c); +#endif } + static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) { // return -(a * b) - c +#if __FMA__ return _mm256_fnmsub_ps(_a, _b, _c); -} +#else + return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1))); #endif +} static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c) { @@ -650,6 +685,16 @@ static NCNN_FORCEINLINE __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2 _mm256_castps256_ps128(s0123)); } +static NCNN_FORCEINLINE __m256 combine4x2_ps(__m128 a, __m128 b) +{ + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); +} + +static NCNN_FORCEINLINE __m256i combine4x2_epi32(__m128i a, __m128i b) +{ + return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1); +} + static NCNN_FORCEINLINE float _mm256_reduce_add_ps(__m256 x) { /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ @@ -841,6 +886,81 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& } #if __AVX2__ +static NCNN_FORCEINLINE __m256i _mm256_comp_dpwssd_epi32(__m256i src, __m256i a, __m256i b) +{ +#if __AVX512VNNI__ + return _mm256_dpwssd_epi32(src, a, b); +#elif __AVXVNNI__ + return _mm256_dpwssd_avx_epi32(src, a, b); +#else + return _mm256_add_epi32(src, _mm256_madd_epi16(a, b)); +#endif +} + +#if __AVX512VNNI__ || __AVXVNNI__ +static NCNN_FORCEINLINE __m128i _mm_comp_dpbusd_epi32(__m128i src, __m128i a, __m128i b) +{ +#if __AVX512VNNI__ + return _mm_dpbusd_epi32(src, a, b); +#else + return _mm_dpbusd_avx_epi32(src, a, b); +#endif +} + +static NCNN_FORCEINLINE __m256i _mm256_comp_dpbusd_epi32(__m256i src, __m256i a, __m256i b) +{ +#if __AVX512VNNI__ + return _mm256_dpbusd_epi32(src, a, b); +#else + return _mm256_dpbusd_avx_epi32(src, a, b); +#endif +} +#endif // __AVX512VNNI__ || __AVXVNNI__ + +static NCNN_FORCEINLINE __m128i _mm_comp_cvtepi32_epi16(__m128i a) +{ +#if __AVX512F__ + return _mm_cvtepi32_epi16(a); +#else + __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0); + return _mm_shuffle_epi8(a, _si); +#endif +} + +static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi16(__m256i a) +{ +#if __AVX512F__ + return _mm256_cvtepi32_epi16(a); +#else + __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0); + __m256i _t = _mm256_shuffle_epi8(a, combine4x2_epi32(_si, _si)); + _t = _mm256_permute4x64_epi64(_t, _MM_SHUFFLE(3, 1, 2, 0)); + return _mm256_castsi256_si128(_t); +#endif +} + +static 
NCNN_FORCEINLINE __m128i _mm_comp_cvtepi32_epi8(__m128i a) +{ +#if __AVX512F__ + return _mm_cvtepi32_epi8(a); +#else + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + return _mm_shuffle_epi8(a, _si); +#endif +} + +static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi8(__m256i a) +{ +#if __AVX512F__ + return _mm256_cvtepi32_epi8(a); +#else + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m256i _t = _mm256_shuffle_epi8(a, combine4x2_epi32(_si, _si)); + _t = _mm256_permute4x64_epi64(_t, _MM_SHUFFLE(3, 1, 2, 0)); + return _mm_shuffle_epi32(_mm256_castsi256_si128(_t), _MM_SHUFFLE(3, 1, 2, 0)); +#endif +} + static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1) { __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1); @@ -890,6 +1010,15 @@ static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m } #if __AVX512F__ +static NCNN_FORCEINLINE __m512i _mm512_comp_dpwssd_epi32(__m512i src, __m512i a, __m512i b) +{ +#if __AVX512VNNI__ + return _mm512_dpwssd_epi32(src, a, b); +#else + return _mm512_add_epi32(src, _mm512_madd_epi16(a, b)); +#endif +} + static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7, __m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf) { @@ -1291,6 +1420,30 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x) return _mm_cvtss_f32(x32); } +static NCNN_FORCEINLINE __m512 combine8x2_ps(__m256 a, __m256 b) +{ + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); +} + +static NCNN_FORCEINLINE __m512 combine4x4_ps(__m128 a, __m128 b, __m128 c, __m128 d) +{ + __m256 ab = combine4x2_ps(a, b); + __m256 cd = combine4x2_ps(c, d); + return combine8x2_ps(ab, cd); +} + +static NCNN_FORCEINLINE __m512i combine8x2_epi32(__m256i a, __m256i b) +{ + return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1); +} + +static NCNN_FORCEINLINE __m512i combine4x4_epi32(__m128i a, __m128i b, __m128i c, __m128i d) +{ + __m256i ab = combine4x2_epi32(a, b); + __m256i cd = combine4x2_epi32(c, d); + return combine8x2_epi32(ab, cd); +} + static NCNN_FORCEINLINE __m128i float2int8_avx512(const __m512& _v0) { // _MM_FROUND_TO_NEAREST_INT round to even diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index d7c6c531a05f..1c3f136d2825 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -15,6 +15,183 @@ #include "testutil.h" #if NCNN_INT8 +static void RandomizeA(ncnn::Mat& m, int transA, float absmax) +{ + if (transA == 0) + { + const int h = m.dims == 3 ? m.c : m.h; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? 
m.channel(i) : m.row(i); + float randabsmax = RandomFloat(absmax * 0.5f, absmax); + randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax)); + randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax)); + + for (int j = 0; j < m.w; j++) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + } + + // set random a and b + p[RandomInt(0, m.w - 1)] = -randabsmax; + p[RandomInt(0, m.w - 1)] = randabsmax; + + // drop 0.45 ~ 0.55 + for (int j = 0; j < m.w; j++) + { + float v = p[j] * (127.f / randabsmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / randabsmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / randabsmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + v = p[j] * (127.f / randabsmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / randabsmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / randabsmax); + bvv = fabs(bv - (int)bv); + } + } + } + } + else // if (transA == 1) + { + std::vector randabsmaxes(m.w); + for (int j = 0; j < m.w; j++) + { + float randabsmax = RandomFloat(absmax * 0.5f, absmax); + randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax)); + randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax)); + randabsmaxes[j] = randabsmax; + } + + const int h = m.dims == 3 ? m.c : m.h; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? m.channel(i) : m.row(i); + for (int j = 0; j < m.w; j++) + { + const float randabsmax = randabsmaxes[j]; + p[j] = RandomFloat(-randabsmax, randabsmax); + } + + // drop 0.45 ~ 0.55 + for (int j = 0; j < m.w; j++) + { + const float randabsmax = randabsmaxes[j]; + float v = p[j] * (127.f / randabsmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / randabsmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / randabsmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + v = p[j] * (127.f / randabsmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / randabsmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / randabsmax); + bvv = fabs(bv - (int)bv); + } + } + } + + for (int j = 0; j < m.w; j++) + { + const int randi0 = RandomInt(0, h - 1); + const int randi1 = RandomInt(0, h - 1); + float* p0 = m.dims == 3 ? m.channel(randi0) : m.row(randi0); + float* p1 = m.dims == 3 ? m.channel(randi1) : m.row(randi1); + + const float randabsmax = randabsmaxes[j]; + + // set random a and b + p0[j] = -randabsmax; + p1[j] = randabsmax; + } + } +} + +static void RandomizeB(ncnn::Mat& m, float absmax) +{ + absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax)); + absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax)); + + const int h = m.dims == 3 ? 
m.c : m.h; + float* p = m; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? m.channel(i) : m.row(i); + for (int j = 0; j < m.w; j++) + { + p[j] = RandomFloat(-absmax, absmax); + + // drop 0.45 ~ 0.55 + float v = p[j] * (127.f / absmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / absmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / absmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-absmax, absmax); + v = p[j] * (127.f / absmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / absmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / absmax); + bvv = fabs(bv - (int)bv); + } + } + } + + // set random a and b + if (m.dims == 3) + { + m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax; + m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax; + } + else + { + m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax; + m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax; + } +} + static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) { ncnn::ParamDict pd; @@ -35,18 +212,21 @@ static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int tran pd.set(18, 2); // int8_scale_term std::vector weights; - if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); - if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M)); + if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K)); if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); std::vector a; - if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); - if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); - - for (size_t i = 0; i < a.size(); i++) + if (!constantA) { - Randomize(a[i], -10.f, 10.f); + a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + RandomizeA(a[a.size() - 1], transA, 10.f); + } + if (!constantB) + { + a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + RandomizeB(a[a.size() - 1], 10.f); } int ret = test_layer("Gemm", pd, weights, a); @@ -118,14 +298,17 @@ static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float al if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); std::vector a; - if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M)); - if (!constantB) a.push_back(transB ? 
ncnn::Mat(K, N) : ncnn::Mat(N, K)); - if (!constantC) a.push_back(C); - - for (size_t i = 0; i < a.size(); i++) + if (!constantA) { - Randomize(a[i], -10.f, 10.f); + a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M)); + RandomizeA(a[a.size() - 1], transA, 10.f); } + if (!constantB) + { + a.push_back(transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K)); + RandomizeB(a[a.size() - 1], 10.f); + } + if (!constantC) a.push_back(C); int ret = test_layer("Gemm", pd, weights, a); if (ret != 0) @@ -156,18 +339,21 @@ static int test_gemm_int8_fp16s(int M, int N, int K, float alpha, int transA, in pd.set(18, 2); // int8_scale_term std::vector weights; - if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); - if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M)); + if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K)); if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); std::vector a; - if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); - if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); - - for (size_t i = 0; i < a.size(); i++) + if (!constantA) + { + a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + RandomizeA(a[a.size() - 1], transA, 10.f); + } + if (!constantB) { - Randomize(a[i], -10.f, 10.f); + a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + RandomizeB(a[a.size() - 1], 10.f); } ncnn::Option opt; diff --git a/tests/test_gemm_4.cpp b/tests/test_gemm_4.cpp index 3b25cf9e9f97..bde08eecb82c 100644 --- a/tests/test_gemm_4.cpp +++ b/tests/test_gemm_4.cpp @@ -15,6 +15,183 @@ #include "testutil.h" #if NCNN_INT8 +static void RandomizeA(ncnn::Mat& m, int transA, float absmax) +{ + if (transA == 0) + { + const int h = m.dims == 3 ? m.c : m.h; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? 
m.channel(i) : m.row(i); + float randabsmax = RandomFloat(absmax * 0.5f, absmax); + randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax)); + randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax)); + + for (int j = 0; j < m.w; j++) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + } + + // set random a and b + p[RandomInt(0, m.w - 1)] = -randabsmax; + p[RandomInt(0, m.w - 1)] = randabsmax; + + // drop 0.45 ~ 0.55 + for (int j = 0; j < m.w; j++) + { + float v = p[j] * (127.f / randabsmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / randabsmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / randabsmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + v = p[j] * (127.f / randabsmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / randabsmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / randabsmax); + bvv = fabs(bv - (int)bv); + } + } + } + } + else // if (transA == 1) + { + std::vector randabsmaxes(m.w); + for (int j = 0; j < m.w; j++) + { + float randabsmax = RandomFloat(absmax * 0.5f, absmax); + randabsmax = ncnn::float16_to_float32(ncnn::float32_to_float16(randabsmax)); + randabsmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(randabsmax)); + randabsmaxes[j] = randabsmax; + } + + const int h = m.dims == 3 ? m.c : m.h; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? m.channel(i) : m.row(i); + for (int j = 0; j < m.w; j++) + { + const float randabsmax = randabsmaxes[j]; + p[j] = RandomFloat(-randabsmax, randabsmax); + } + + // drop 0.45 ~ 0.55 + for (int j = 0; j < m.w; j++) + { + const float randabsmax = randabsmaxes[j]; + float v = p[j] * (127.f / randabsmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / randabsmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / randabsmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-randabsmax, randabsmax); + v = p[j] * (127.f / randabsmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / randabsmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / randabsmax); + bvv = fabs(bv - (int)bv); + } + } + } + + for (int j = 0; j < m.w; j++) + { + const int randi0 = RandomInt(0, h - 1); + const int randi1 = RandomInt(0, h - 1); + float* p0 = m.dims == 3 ? m.channel(randi0) : m.row(randi0); + float* p1 = m.dims == 3 ? m.channel(randi1) : m.row(randi1); + + const float randabsmax = randabsmaxes[j]; + + // set random a and b + p0[j] = -randabsmax; + p1[j] = randabsmax; + } + } +} + +static void RandomizeB(ncnn::Mat& m, float absmax) +{ + absmax = ncnn::float16_to_float32(ncnn::float32_to_float16(absmax)); + absmax = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(absmax)); + + const int h = m.dims == 3 ? 
m.c : m.h; + float* p = m; + for (int i = 0; i < h; i++) + { + float* p = m.dims == 3 ? m.channel(i) : m.row(i); + for (int j = 0; j < m.w; j++) + { + p[j] = RandomFloat(-absmax, absmax); + + // drop 0.45 ~ 0.55 + float v = p[j] * (127.f / absmax); + float vv = fabs(v - (int)v); + + float hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + float hv = hp * (127.f / absmax); + float hvv = fabs(hv - (int)hv); + + float bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + float bv = bp * (127.f / absmax); + float bvv = fabs(bv - (int)bv); + + while ((vv > 0.45f && vv < 0.55f) || (hvv > 0.45f && hvv < 0.55f) || (bvv > 0.45f && bvv < 0.55f)) + { + p[j] = RandomFloat(-absmax, absmax); + v = p[j] * (127.f / absmax); + vv = fabs(v - (int)v); + + hp = ncnn::float16_to_float32(ncnn::float32_to_float16(p[j])); + hv = hp * (127.f / absmax); + hvv = fabs(hv - (int)hv); + + bp = ncnn::bfloat16_to_float32(ncnn::float32_to_bfloat16(p[j])); + bv = bp * (127.f / absmax); + bvv = fabs(bv - (int)bv); + } + } + } + + // set random a and b + if (m.dims == 3) + { + m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax; + m.channel(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax; + } + else + { + m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = -absmax; + m.row(RandomInt(0, h - 1))[RandomInt(0, m.w - 1)] = absmax; + } +} + static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose) { ncnn::ParamDict pd; @@ -35,8 +212,8 @@ static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_ a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M); a[1] = transB ? ncnn::Mat(K, N) : ncnn::Mat(N, K); - Randomize(a[0], -10.f, 10.f); - Randomize(a[1], -10.f, 10.f); + RandomizeA(a[0], transA, 10.f); + RandomizeB(a[1], 10.f); int ret = test_layer("Gemm", pd, weights, a); if (ret != 0)