diff --git a/.ci/linux-x64-cpu-gcc.yml b/.ci/linux-x64-cpu-gcc.yml index 4f138d9d080..f0bf4ce1ae1 100644 --- a/.ci/linux-x64-cpu-gcc.yml +++ b/.ci/linux-x64-cpu-gcc.yml @@ -117,3 +117,11 @@ jobs: cmake --build . -j $(nproc) - name: test-simplestl-simpleomp run: cd build-simplestl-simpleomp && ctest --output-on-failure -j $(nproc) + - name: build-simplestl-simplemath + run: | + mkdir build-simplestl-simplemath && cd build-simplestl-simplemath + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEMATH=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . -j $(nproc) + - name: test-simplestl-simplemath + run: cd build-simplestl-simplemath && ctest --output-on-failure -j $(nproc) + diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 267a0afa289..3f116a4fa2e 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -48,6 +48,10 @@ jobs: torchvision-version: 0.15.1 torchvision-cache-key: '0_15_1' + - torch-version: 2.1.0 + torchvision-version: 0.16.0 + torchvision-cache-key: '0_16_0' + runs-on: pool-name: docker container: diff --git a/.ci/test-coverage.yml b/.ci/test-coverage.yml index a693f415883..1c5e72edc7c 100644 --- a/.ci/test-coverage.yml +++ b/.ci/test-coverage.yml @@ -908,3 +908,47 @@ jobs: lcov --list lcov.info - name: codecov run: ./codecov -t ${{settings.CODECOV_TOKEN.access_token}} -C ${{ ci.sha }} -B ${{ ci.head_ref }} -f build/lcov.info + + linux-gcc-x64-simplemath: + name: linux-gcc-x64-simplemath + + runs-on: + pool-name: docker + container: + image: bkci/ci:ubuntu + steps: + - name: checkout + checkout: self + with: + strategy: FRESH_CHECKOUT + enableSubmodule: false + enableGitLfs: false + + - name: install-deps + run: | + apt-get update + apt-get install -y lcov + curl https://uploader.codecov.io/verification.gpg | gpg --no-default-keyring --keyring trustedkeys.gpg --import + curl -Os https://uploader.codecov.io/latest/linux/codecov + curl -Os https://uploader.codecov.io/latest/linux/codecov.SHA256SUM + curl -Os https://uploader.codecov.io/latest/linux/codecov.SHA256SUM.sig + gpgv codecov.SHA256SUM.sig codecov.SHA256SUM + shasum -a 256 -c codecov.SHA256SUM + chmod +x codecov + + - name: build + run: | + mkdir build && cd build + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DCMAKE_BUILD_TYPE=debug -DNCNN_COVERAGE=ON -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEMATH=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . 
-j $(nproc) + - name: test + run: cd build && ctest --output-on-failure -j $(nproc) + - name: lcov-collect + run: | + cd build + lcov -d ./src -c -o lcov.info + lcov -r lcov.info '/usr/*' -o lcov.info + lcov -r lcov.info '*/build/*' -o lcov.info + lcov --list lcov.info + - name: codecov + run: ./codecov -t ${{settings.CODECOV_TOKEN.access_token}} -C ${{ ci.sha }} -B ${{ ci.head_ref }} -f build/lcov.info \ No newline at end of file diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 00000000000..31d03ede341 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,24 @@ +cmake: +- cmake/** +- toolchains/** + +doc: docs/** + +python: python/** + +example: examples/** + +test: tests/** + +tool: tools/** +pnnx: tools/pnnx/** + +core: src/* +layer: src/layer/* + +arm: src/layer/arm/** +loongarch: src/layer/loongarch/** +mips: src/layer/mips/** +riscv: src/layer/riscv/** +vulkan: src/layer/vulkan/** +x86: src/layer/x86/** diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 9ffb978181d..8051371d3e5 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -51,7 +51,7 @@ jobs: rm -rf $GITHUB_WORKSPACE/clang-format-install export PATH=~/bin:$PATH sh codeformat.sh - - uses: stefanzweifel/git-auto-commit-action@v4 + - uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: apply code-format changes diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000000..889c41b54c9 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,12 @@ +name: labeler +on: [pull_request_target] + +permissions: + contents: read + pull-requests: write + +jobs: + label: + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v4 diff --git a/.github/workflows/linux-aarch64-cpu-gcc.yml b/.github/workflows/linux-aarch64-cpu-gcc.yml index 46179097aec..a791da6c26a 100644 --- a/.github/workflows/linux-aarch64-cpu-gcc.yml +++ b/.github/workflows/linux-aarch64-cpu-gcc.yml @@ -86,6 +86,17 @@ jobs: export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH cd build-noint8 TESTS_EXECUTABLE_LOADER=qemu-aarch64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-L;/usr/aarch64-linux-gnu" ctest --output-on-failure -j 2 + + - name: build-simplestl-simplemath + run: | + mkdir build-simplestl-simplemath && cd build-simplestl-simplemath + cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu-c.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEMATH=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . 
-j 2 + - name: test-simplestl-simplemath + run: | + export PATH=$GITHUB_WORKSPACE/qemu-install/bin:$PATH + cd build-simplestl-simplemath + TESTS_EXECUTABLE_LOADER=qemu-aarch64 TESTS_EXECUTABLE_LOADER_ARGUMENTS="-L;/usr/aarch64-linux-gnu" ctest --output-on-failure -j 2 linux-gcc-arm82: runs-on: ubuntu-20.04 diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index 9d99ffdd630..20dd5ddc4cc 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -5,6 +5,14 @@ on: tags: - '*' +env: + DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer + MAC_DEPLOYMENT_TARGET: '10.9' + MAC_ARM64_DEPLOYMENT_TARGET: '11.0' + ENABLE_BITCODE: OFF + ENABLE_ARC: OFF + ENABLE_VISIBILITY: OFF + jobs: build_sdist: name: Build SDist @@ -51,7 +59,6 @@ jobs: - { os: macos-latest, arch: x86_64, build: 'cp*' } - { os: macos-latest, arch: x86_64, build: 'pp*' } - { os: macos-latest, arch: arm64, build: 'cp*' } - - { os: macos-latest, arch: universal2, build: 'cp*' } steps: - uses: actions/checkout@v4 @@ -62,20 +69,223 @@ jobs: with: python-version: '3.x' - - name: brew uninstall libomp + # build wheels for ubuntu-20.04 + - name: Build wheels for ubuntu manylinux + if: matrix.os == 'ubuntu-20.04' && matrix.build != 'cp*-musllinux*' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_LINUX: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON + VULKAN_SDK=/project/Vulkan-Loader/build/Vulkan-Headers + LD_LIBRARY_PATH=/project/Vulkan-Loader/build/loader + CIBW_BEFORE_ALL: yum -y install libXrandr-devel && + git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + ../scripts/update_deps.py && + cmake -DCMAKE_BUILD_TYPE=Release -DVULKAN_HEADERS_INSTALL_DIR=$(pwd)/Vulkan-Headers/build/install .. && + make -j$(nproc) && + cd Vulkan-Headers && + ln -s ../loader lib + with: + output-dir: wheelhouse + + - name: Build wheels for ubuntu musllinux + if: matrix.os == 'ubuntu-20.04' && matrix.build == 'cp*-musllinux*' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_LINUX: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON + VULKAN_SDK=/project/Vulkan-Loader/build/Vulkan-Headers + LD_LIBRARY_PATH=/project/Vulkan-Loader/build/loader + CIBW_BEFORE_ALL: apk add libxrandr-dev && + git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + ../scripts/update_deps.py && + cmake -DCMAKE_BUILD_TYPE=Release -DVULKAN_HEADERS_INSTALL_DIR=$(pwd)/Vulkan-Headers/build/install .. 
&& + make -j$(nproc) && + cd Vulkan-Headers && + ln -s ../loader lib + with: + output-dir: wheelhouse + + # build wheels for windows-2019 + - name: Build wheels for windows amd64 + if: matrix.os == 'windows-2019' && matrix.arch == 'AMD64' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_WINDOWS: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT_WINDOWS: > + PATH="D:\\a\\ncnn\\ncnn\\Vulkan-Loader\\build\\loader\\Release;$PATH" + CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON + VULKAN_SDK=D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers + CIBW_BEFORE_ALL: git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + python3 ../scripts/update_deps.py --dir ../external --config release && + cmake -C ../external/helper.cmake -G "Visual Studio 16 2019" -A x64 -DCMAKE_BUILD_TYPE=Release .. && + cmake --build . --config Release && + mklink /d "D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers/build/install/lib" + "D:/a/ncnn/ncnn/Vulkan-Loader/build/loader/Release" + CIBW_BEFORE_BUILD: pip install delvewheel + CIBW_REPAIR_WHEEL_COMMAND: delvewheel repair -w {dest_dir} {wheel} + with: + output-dir: wheelhouse + + - name: Build wheels for windows x86 + if: matrix.os == 'windows-2019' && matrix.arch == 'x86' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_WINDOWS: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT_WINDOWS: > + PATH="D:\\a\\ncnn\\ncnn\\Vulkan-Loader\\build\\loader\\Release;$PATH" + CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON + VULKAN_SDK=D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers + CIBW_BEFORE_ALL: git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + python3 ../scripts/update_deps.py --dir ../external --arch ${{ matrix.arch }} --config release && + cmake -C ../external/helper.cmake -G "Visual Studio 16 2019" -A Win32 -DCMAKE_BUILD_TYPE=Release .. && + cmake --build . --config Release && + mklink /d "D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers/build/install/lib" + "D:/a/ncnn/ncnn/Vulkan-Loader/build/loader/Release" + CIBW_BEFORE_BUILD: pip install delvewheel + CIBW_REPAIR_WHEEL_COMMAND: delvewheel repair -w {dest_dir} {wheel} + with: + output-dir: wheelhouse + + - name: Build wheels for windows ARM64 + if: matrix.os == 'windows-2019' && matrix.arch == 'ARM64' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_WINDOWS: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT_WINDOWS: > + PATH="D:\\a\\ncnn\\ncnn\\Vulkan-Loader\\build\\loader\\Release;$PATH" + CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON + VULKAN_SDK=D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers + CIBW_BEFORE_ALL: git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + python3 ../scripts/update_deps.py --dir ../external --config release && + cmake -C ../external/helper.cmake -G "Visual Studio 16 2019" -A ARM64 -DCMAKE_BUILD_TYPE=Release -DUSE_MASM=OFF .. && + cmake --build . 
--config Release && + mklink /d "D:/a/ncnn/ncnn/Vulkan-Loader/external/Vulkan-Headers/build/install/lib" + "D:/a/ncnn/ncnn/Vulkan-Loader/build/loader/Release" + CIBW_BEFORE_BUILD: pip install delvewheel + CIBW_REPAIR_WHEEL_COMMAND: delvewheel repair -w {dest_dir} {wheel} --no-dll "msvcp140.dll;vcomp140.dll" + with: + output-dir: wheelhouse + + # build wheels for macos-latest + - name: cache-openmp for macos if: matrix.os == 'macos-latest' + id: cache-openmp + uses: actions/cache@v3 + with: + path: openmp-install + key: openmp-macos-install-20230504 + + - name: openmp for macos + if: matrix.os == 'macos-latest' && steps.cache-openmp.outputs.cache-hit != 'true' + run: | + wget https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/openmp-11.0.0.src.tar.xz + tar -xf openmp-11.0.0.src.tar.xz + cd openmp-11.0.0.src + sed -i'' -e '/.size __kmp_unnamed_critical_addr/d' runtime/src/z_Linux_asm.S + sed -i'' -e 's/__kmp_unnamed_critical_addr/___kmp_unnamed_critical_addr/g' runtime/src/z_Linux_asm.S + + - name: openmp-build-x86_64 for macos + if: matrix.os == 'macos-latest' && steps.cache-openmp.outputs.cache-hit != 'true' + run: | + cd openmp-11.0.0.src + mkdir -p build-x86_64 && cd build-x86_64 + cmake -DCMAKE_TOOLCHAIN_FILE=$GITHUB_WORKSPACE/toolchains/ios.toolchain.cmake -DPLATFORM=MAC -DARCHS="x86_64" \ + -DDEPLOYMENT_TARGET=$MAC_DEPLOYMENT_TARGET -DENABLE_BITCODE=$ENABLE_BITCODE -DENABLE_ARC=$ENABLE_ARC -DENABLE_VISIBILITY=$ENABLE_VISIBILITY \ + -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=install \ + -DPERL_EXECUTABLE=/usr/local/bin/perl \ + -DLIBOMP_ENABLE_SHARED=OFF -DLIBOMP_OMPT_SUPPORT=OFF -DLIBOMP_USE_HWLOC=OFF .. + cmake --build . -j 3 + cmake --build . --target install + + - name: openmp-build-arm64 for macos + if: matrix.os == 'macos-latest' && steps.cache-openmp.outputs.cache-hit != 'true' + run: | + cd openmp-11.0.0.src + mkdir -p build-arm64 && cd build-arm64 + cmake -DCMAKE_TOOLCHAIN_FILE=$GITHUB_WORKSPACE/toolchains/ios.toolchain.cmake -DPLATFORM=MAC_ARM64 -DARCHS="arm64" \ + -DDEPLOYMENT_TARGET=$MAC_ARM64_DEPLOYMENT_TARGET -DENABLE_BITCODE=$ENABLE_BITCODE -DENABLE_ARC=$ENABLE_ARC -DENABLE_VISIBILITY=$ENABLE_VISIBILITY \ + -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=install \ + -DPERL_EXECUTABLE=/usr/local/bin/perl \ + -DLIBOMP_ENABLE_SHARED=OFF -DLIBOMP_OMPT_SUPPORT=OFF -DLIBOMP_USE_HWLOC=OFF .. + cmake --build . -j 3 + cmake --build . 
--target install + + - name: openmp-merge-fat-library for macos + if: matrix.os == 'macos-latest' && steps.cache-openmp.outputs.cache-hit != 'true' run: | - brew uninstall --ignore-dependencies libomp + mkdir -p $GITHUB_WORKSPACE/openmp-install + cp -a openmp-11.0.0.src/build-x86_64/install/include $GITHUB_WORKSPACE/openmp-install + mkdir -p $GITHUB_WORKSPACE/openmp-install/lib + lipo -create \ + openmp-11.0.0.src/build-x86_64/install/lib/libomp.a \ + openmp-11.0.0.src/build-arm64/install/lib/libomp.a \ + -o $GITHUB_WORKSPACE/openmp-install/lib/libomp.a - - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 + - name: install-openmp for macos + if: matrix.os == 'macos-latest' + run: | + sudo cp $GITHUB_WORKSPACE/openmp-install/include/* $DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include + sudo cp $GITHUB_WORKSPACE/openmp-install/lib/libomp.a $DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/lib + + - name: vulkansdk for macos + if: matrix.os == 'macos-latest' + run: | + wget https://sdk.lunarg.com/sdk/download/1.3.236.0/mac/vulkansdk-macos-1.3.236.0.dmg?Human=true -O vulkansdk-macos-1.3.236.0.dmg + hdiutil attach vulkansdk-macos-1.3.236.0.dmg + sudo /Volumes/vulkansdk-macos-1.3.236.0/InstallVulkan.app/Contents/MacOS/InstallVulkan --root $GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0 --accept-licenses --default-answer --confirm-command install + hdiutil detach /Volumes/vulkansdk-macos-1.3.236.0 + + - name: Build wheels for macos x86_64 + if: matrix.os == 'macos-latest' && matrix.arch == 'x86_64' + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_MACOS: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON VULKAN_SDK=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/macOS + CMAKE_TOOLCHAIN_FILE=$GITHUB_WORKSPACE/toolchains/ios.toolchain.cmake PLATFORM=MAC ARCHS="x86_64" + DEPLOYMENT_TARGET="10.9" ENABLE_BITCODE=OFF ENABLE_ARC=OFF ENABLE_VISIBILITY=OFF + OpenMP_C_FLAGS="-Xclang -fopenmp" OpenMP_CXX_FLAGS="-Xclang -fopenmp" + OpenMP_C_LIB_NAMES="libomp" OpenMP_CXX_LIB_NAMES="libomp" + OpenMP_libomp_LIBRARY="$DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/lib/libomp.a" + Vulkan_INCLUDE_DIR=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/MoltenVK/include + Vulkan_LIBRARY=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/MoltenVK/dylib/macOS/libMoltenVK.dylib + with: + output-dir: wheelhouse + + - name: Build wheels for macos arm64 + if: matrix.os == 'macos-latest' && matrix.arch == 'arm64' + uses: pypa/cibuildwheel@v2.16.2 env: CIBW_ARCHS_MACOS: ${{ matrix.arch }} - CIBW_ARCHS_LINUX: ${{ matrix.arch }} - CIBW_ARCHS_WINDOWS: ${{ matrix.arch }} CIBW_BUILD: ${{ matrix.build }} CIBW_BUILD_VERBOSITY: 1 - CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON VULKAN_SDK=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/macOS + CMAKE_TOOLCHAIN_FILE=$GITHUB_WORKSPACE/toolchains/ios.toolchain.cmake PLATFORM=MAC_ARM64 ARCHS="arm64" + DEPLOYMENT_TARGET="11.0" ENABLE_BITCODE=OFF ENABLE_ARC=OFF ENABLE_VISIBILITY=OFF + OpenMP_C_FLAGS="-Xclang -fopenmp" OpenMP_CXX_FLAGS="-Xclang -fopenmp" + OpenMP_C_LIB_NAMES="libomp" OpenMP_CXX_LIB_NAMES="libomp" + OpenMP_libomp_LIBRARY="$DEVELOPER_DIR/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/lib/libomp.a" + Vulkan_INCLUDE_DIR=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/MoltenVK/include + 
Vulkan_LIBRARY=$GITHUB_WORKSPACE/vulkansdk-macos-1.3.236.0/MoltenVK/dylib/macOS/libMoltenVK.dylib + with: + output-dir: wheelhouse - name: Show files run: ls -lh wheelhouse @@ -98,7 +308,11 @@ jobs: fail-fast: false matrix: arch: [aarch64, ppc64le, s390x] - build: ['cp36-*', 'cp37-*', 'cp38-*', 'cp39-*', 'cp310-*', 'cp311-*', 'cp312-*'] + build: [ 'cp36-manylinux*', 'cp37-manylinux*', 'cp38-manylinux*', + 'cp39-manylinux*', 'cp310-manylinux*', 'cp311-manylinux*', + 'cp312-manylinux*', 'cp36-musllinux*', 'cp37-musllinux*', + 'cp38-musllinux*', 'cp39-musllinux*', 'cp310-musllinux*', + 'cp311-musllinux*', 'cp312-musllinux*' ] include: - arch: aarch64 build: 'pp37-*' @@ -123,13 +337,50 @@ jobs: with: platforms: all - - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 + - name: Build wheels for manylinux with qemu + if: (matrix.build != 'cp36-musllinux*') && (matrix.build != 'cp37-musllinux*') && + (matrix.build != 'cp38-musllinux*') && (matrix.build != 'cp39-musllinux*') && + (matrix.build != 'cp310-musllinux*') && (matrix.build != 'cp311-musllinux*') && + (matrix.build != 'cp312-musllinux*') + uses: pypa/cibuildwheel@v2.16.2 env: CIBW_ARCHS_LINUX: ${{ matrix.arch }} CIBW_BUILD: ${{ matrix.build }} CIBW_BUILD_VERBOSITY: 1 - CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON VULKAN_SDK=/project/Vulkan-Loader/build/Vulkan-Headers + LD_LIBRARY_PATH=/project/Vulkan-Loader/build/loader + CIBW_BEFORE_ALL: yum -y install libXrandr-devel && + git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + ../scripts/update_deps.py && + cmake -DCMAKE_BUILD_TYPE=Release -DVULKAN_HEADERS_INSTALL_DIR=$(pwd)/Vulkan-Headers/build/install .. && + make -j$(nproc) && + cd Vulkan-Headers && + ln -s ../loader lib + with: + output-dir: wheelhouse + + - name: Build wheels for musllinux with qemu + if: (matrix.build == 'cp36-musllinux*') || (matrix.build == 'cp37-musllinux*') || + (matrix.build == 'cp38-musllinux*') || (matrix.build == 'cp39-musllinux*') || + (matrix.build == 'cp310-musllinux*') || (matrix.build == 'cp311-musllinux*') || + (matrix.build == 'cp312-musllinux*') + uses: pypa/cibuildwheel@v2.16.2 + env: + CIBW_ARCHS_LINUX: ${{ matrix.arch }} + CIBW_BUILD: ${{ matrix.build }} + CIBW_BUILD_VERBOSITY: 1 + CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=2 NCNN_VULKAN=ON VULKAN_SDK=/project/Vulkan-Loader/build/Vulkan-Headers LD_LIBRARY_PATH=/project/Vulkan-Loader/build/loader + CIBW_BEFORE_ALL: apk add libxrandr-dev && + git clone https://github.com/KhronosGroup/Vulkan-Loader.git && + cd Vulkan-Loader && mkdir build && cd build && + ../scripts/update_deps.py && + cmake -DCMAKE_BUILD_TYPE=Release -DVULKAN_HEADERS_INSTALL_DIR=$(pwd)/Vulkan-Headers/build/install .. 
&& + make -j$(nproc) && + cd Vulkan-Headers && + ln -s ../loader lib + with: + output-dir: wheelhouse - name: Show files run: ls -lh wheelhouse diff --git a/CMakeLists.txt b/CMakeLists.txt index 35a586ecda2..b6907207444 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,6 +56,7 @@ option(NCNN_INSTALL_SDK "install ncnn library and headers" ON) option(NCNN_SIMPLEOCV "minimal opencv structure emulation" OFF) option(NCNN_SIMPLEOMP "minimal openmp runtime emulation" OFF) option(NCNN_SIMPLESTL "minimal cpp stl structure emulation" OFF) +option(NCNN_SIMPLEMATH "minimal cmath" OFF) option(NCNN_THREADS "build with threads" ON) option(NCNN_BENCHMARK "print benchmark information for every layer" OFF) option(NCNN_C_API "build with C api" ON) diff --git a/benchmark/README.md b/benchmark/README.md index 83347dec97b..211480893b8 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1631,7 +1631,92 @@ cooling_down = 1 vision_transformer min = 6605.19 max = 6606.66 avg = 6605.73 FastestDet min = 52.11 max = 52.97 avg = 52.61 ``` +### Raspberry Pi 5 Broadcom BCM2712, Cortex-A76 (ARMv8) (2.4GHz x 4) +``` +pi@raspberrypi:~/ncnn/benchmark $ ./benchncnn 10 4 0 -1 1 +loop_count = 10 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 8.56 max = 8.65 avg = 8.61 + squeezenet_int8 min = 11.65 max = 12.64 avg = 11.94 + mobilenet min = 11.32 max = 13.46 avg = 11.75 + mobilenet_int8 min = 11.30 max = 11.60 avg = 11.45 + mobilenet_v2 min = 13.57 max = 13.77 avg = 13.63 + mobilenet_v3 min = 9.18 max = 10.52 avg = 9.48 + shufflenet min = 4.56 max = 6.19 avg = 5.98 + shufflenet_v2 min = 5.04 max = 5.13 avg = 5.09 + mnasnet min = 8.27 max = 9.86 avg = 8.65 + proxylessnasnet min = 9.36 max = 11.18 avg = 9.62 + efficientnet_b0 min = 14.77 max = 14.96 avg = 14.87 + efficientnetv2_b0 min = 19.91 max = 20.11 avg = 19.99 + regnety_400m min = 11.91 max = 12.10 avg = 11.96 + blazeface min = 2.26 max = 2.29 avg = 2.28 + googlenet min = 32.80 max = 33.17 avg = 32.97 + googlenet_int8 min = 32.63 max = 32.99 avg = 32.78 + resnet18 min = 23.95 max = 24.21 avg = 24.12 + resnet18_int8 min = 32.50 max = 32.79 avg = 32.68 + alexnet min = 25.31 max = 25.75 avg = 25.51 + vgg16 min = 162.19 max = 165.08 avg = 163.75 + vgg16_int8 min = 187.46 max = 191.21 avg = 189.09 + resnet50 min = 55.95 max = 56.61 avg = 56.29 + resnet50_int8 min = 73.34 max = 73.97 avg = 73.59 + squeezenet_ssd min = 40.48 max = 41.39 avg = 40.92 + squeezenet_ssd_int8 min = 45.67 max = 46.35 avg = 46.06 + mobilenet_ssd min = 31.15 max = 31.73 avg = 31.48 + mobilenet_ssd_int8 min = 31.09 max = 31.44 avg = 31.27 + mobilenet_yolo min = 71.51 max = 72.38 avg = 71.95 + mobilenetv2_yolov3 min = 47.86 max = 48.41 avg = 48.04 + yolov4-tiny min = 55.95 max = 56.51 avg = 56.19 + nanodet_m min = 14.26 max = 14.68 avg = 14.48 + yolo-fastest-1.1 min = 6.48 max = 8.10 avg = 7.30 + yolo-fastestv2 min = 6.03 max = 7.33 avg = 7.04 + vision_transformer min = 613.62 max = 637.97 avg = 629.51 + FastestDet min = 6.53 max = 6.66 avg = 6.59 +pi@raspberrypi:~/ncnn/benchmark $ ./benchncnn 10 1 0 -1 1 +loop_count = 10 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 13.18 max = 13.27 avg = 13.22 + squeezenet_int8 min = 15.69 max = 15.93 avg = 15.78 + mobilenet min = 21.42 max = 21.55 avg = 21.46 + mobilenet_int8 min = 14.92 max = 20.91 avg = 17.34 + mobilenet_v2 min = 18.56 max = 23.06 avg = 19.24 + mobilenet_v3 min = 13.16 max = 13.33 avg = 13.25 + shufflenet min = 7.25 max = 11.14 avg = 8.43 + shufflenet_v2 min = 
7.17 max = 11.15 avg = 7.70 + mnasnet min = 13.89 max = 13.94 avg = 13.91 + proxylessnasnet min = 17.01 max = 17.26 avg = 17.07 + efficientnet_b0 min = 26.19 max = 26.30 avg = 26.24 + efficientnetv2_b0 min = 39.69 max = 40.12 avg = 39.97 + regnety_400m min = 17.30 max = 17.44 avg = 17.36 + blazeface min = 4.74 max = 4.78 avg = 4.76 + googlenet min = 57.64 max = 57.84 avg = 57.72 + googlenet_int8 min = 55.80 max = 56.01 avg = 55.93 + resnet18 min = 31.90 max = 32.09 avg = 32.00 + resnet18_int8 min = 56.92 max = 57.16 avg = 57.01 + alexnet min = 39.84 max = 40.12 avg = 39.92 + vgg16 min = 208.33 max = 211.06 avg = 209.64 + vgg16_int8 min = 437.53 max = 440.55 avg = 439.35 + resnet50 min = 95.75 max = 96.68 avg = 96.28 + resnet50_int8 min = 116.80 max = 118.01 avg = 117.57 + squeezenet_ssd min = 47.75 max = 47.97 avg = 47.86 + squeezenet_ssd_int8 min = 61.98 max = 62.90 avg = 62.47 + mobilenet_ssd min = 52.83 max = 53.39 avg = 53.07 + mobilenet_ssd_int8 min = 46.15 max = 46.60 avg = 46.35 + mobilenet_yolo min = 117.68 max = 117.97 avg = 117.81 + mobilenetv2_yolov3 min = 67.37 max = 67.67 avg = 67.48 + yolov4-tiny min = 73.85 max = 74.35 avg = 74.10 + nanodet_m min = 22.78 max = 23.33 avg = 22.96 + yolo-fastest-1.1 min = 8.82 max = 8.91 avg = 8.87 + yolo-fastestv2 min = 8.18 max = 11.42 avg = 8.59 + vision_transformer min = 1267.90 max = 1269.45 avg = 1268.82 + FastestDet min = 7.79 max = 11.14 avg = 9.03 +``` ### Raspberry Pi Zero 2 W Broadcom BCM2710A1, Cortex-A53 (ARMv8) (1.0GHz x 4) ``` @@ -5079,6 +5164,61 @@ cooling_down = 0 mobilenetv2_yolov3 min = 3.69 max = 5.14 avg = 3.91 ``` +### nVIDIA RTX A3000 of Notebook (6GB) +``` +cx@HP-ZBook-Fury-15-6-inch-G8-Mobile-Workstation-PC:~/ncnn/build/benchmark$ ./benchncnn 10 1 0 1 +[0 Intel(R) UHD Graphics (TGL GT1)] queueC=0[1] queueG=0[1] queueT=0[1] +[0 Intel(R) UHD Graphics (TGL GT1)] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 Intel(R) UHD Graphics (TGL GT1)] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 Intel(R) UHD Graphics (TGL GT1)] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 Intel(R) UHD Graphics (TGL GT1)] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +[1 NVIDIA RTX A3000 Laptop GPU] queueC=2[8] queueG=0[16] queueT=1[2] +[1 NVIDIA RTX A3000 Laptop GPU] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 NVIDIA RTX A3000 Laptop GPU] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 NVIDIA RTX A3000 Laptop GPU] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 NVIDIA RTX A3000 Laptop GPU] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +loop_count = 10 +num_threads = 1 +powersave = 0 +gpu_device = 1 +cooling_down = 1 + squeezenet min = 1.49 max = 1.94 avg = 1.74 + squeezenet_int8 min = 6.13 max = 6.20 avg = 6.16 + mobilenet min = 4.05 max = 4.82 avg = 4.65 + mobilenet_int8 min = 10.24 max = 10.29 avg = 10.26 + mobilenet_v2 min = 0.98 max = 1.14 avg = 1.03 + mobilenet_v3 min = 1.74 max = 1.82 avg = 1.77 + shufflenet min = 1.43 max = 30.51 avg = 9.51 + shufflenet_v2 min = 3.43 max = 3.89 avg = 3.77 + mnasnet min = 6.50 max = 6.75 avg = 6.62 + proxylessnasnet min = 6.46 max = 7.28 avg = 7.00 + efficientnet_b0 min = 3.14 max = 15.11 avg = 7.29 + efficientnetv2_b0 min = 18.50 max = 20.13 avg = 19.17 + regnety_400m min = 2.16 max = 3.57 avg = 2.70 + blazeface min = 2.52 max = 2.76 avg = 2.65 + googlenet min = 2.67 max = 14.67 avg = 9.85 + googlenet_int8 min = 19.08 max = 19.40 avg = 19.19 + resnet18 min = 5.19 max = 9.44 avg = 8.48 + resnet18_int8 min = 16.57 max = 17.69 avg = 16.96 + alexnet min = 1.98 max = 3.24 avg = 2.23 + vgg16 min = 3.59 max = 12.34 avg = 10.99 + 
vgg16_int8 min = 110.63 max = 124.31 avg = 118.16 + resnet50 min = 3.01 max = 4.93 avg = 3.77 + resnet50_int8 min = 41.58 max = 44.80 avg = 43.24 + squeezenet_ssd min = 4.08 max = 4.70 avg = 4.32 + squeezenet_ssd_int8 min = 17.32 max = 17.92 avg = 17.46 + mobilenet_ssd min = 2.26 max = 8.23 avg = 5.57 + mobilenet_ssd_int8 min = 20.35 max = 21.89 avg = 20.76 + mobilenet_yolo min = 2.14 max = 16.94 avg = 6.44 + mobilenetv2_yolov3 min = 3.64 max = 5.09 avg = 4.02 + yolov4-tiny min = 10.94 max = 17.46 avg = 13.58 + nanodet_m min = 6.57 max = 13.91 avg = 9.82 + yolo-fastest-1.1 min = 5.40 max = 14.22 avg = 10.78 + yolo-fastestv2 min = 7.49 max = 9.43 avg = 7.99 + vision_transformer min = 76.04 max = 76.96 avg = 76.43 + FastestDet min = 6.31 max = 6.60 avg = 6.43 +``` + ### nVIDIA RTX2080 of Desktop ``` E:\projects\framework\ncnn\benchmark>benchncnn.exe 4096 1 0 0 0 @@ -6508,3 +6648,253 @@ cooling_down = 0 vision_transformer min = 600.83 max = 666.35 avg = 617.33 FastestDet min = 6.05 max = 6.72 avg = 6.23 ``` + +### AMD Ryzen 9 5950X 16-Core of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 -1 0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.68 max = 3.10 avg = 2.77 + squeezenet_int8 min = 3.57 max = 4.72 avg = 4.04 + mobilenet min = 3.09 max = 5.44 avg = 3.38 + mobilenet_int8 min = 2.36 max = 3.40 avg = 2.74 + mobilenet_v2 min = 4.24 max = 4.81 avg = 4.40 + mobilenet_v3 min = 3.46 max = 3.93 avg = 3.58 + shufflenet min = 3.21 max = 4.54 avg = 4.01 + shufflenet_v2 min = 2.99 max = 4.49 avg = 3.34 + mnasnet min = 3.62 max = 4.31 avg = 3.83 + proxylessnasnet min = 4.06 max = 5.70 avg = 4.23 + efficientnet_b0 min = 5.60 max = 6.55 avg = 5.81 + efficientnetv2_b0 min = 6.83 max = 8.82 avg = 7.12 + regnety_400m min = 8.02 max = 9.75 avg = 8.34 + blazeface min = 1.34 max = 1.77 avg = 1.46 + googlenet min = 11.62 max = 15.95 avg = 12.70 + googlenet_int8 min = 7.43 max = 10.06 avg = 7.92 + resnet18 min = 8.39 max = 10.39 avg = 9.04 + resnet18_int8 min = 6.23 max = 8.64 avg = 6.75 + alexnet min = 7.78 max = 12.51 avg = 8.51 + vgg16 min = 53.85 max = 63.39 avg = 56.36 + vgg16_int8 min = 35.61 max = 46.94 avg = 38.08 + resnet50 min = 18.55 max = 24.46 avg = 19.81 + resnet50_int8 min = 11.95 max = 23.21 avg = 13.51 + squeezenet_ssd min = 10.01 max = 13.16 avg = 10.69 + squeezenet_ssd_int8 min = 9.29 max = 14.02 avg = 10.47 + mobilenet_ssd min = 6.38 max = 10.26 avg = 7.15 + mobilenet_ssd_int8 min = 4.69 max = 6.98 avg = 5.42 + mobilenet_yolo min = 17.63 max = 22.59 avg = 19.45 + mobilenetv2_yolov3 min = 11.79 max = 15.67 avg = 12.76 + yolov4-tiny min = 21.53 max = 25.79 avg = 22.46 + nanodet_m min = 7.16 max = 9.99 avg = 8.01 + yolo-fastest-1.1 min = 3.66 max = 5.00 avg = 4.38 + yolo-fastestv2 min = 3.52 max = 5.20 avg = 4.60 + vision_transformer min = 67.01 max = 93.71 avg = 78.48 + FastestDet min = 4.44 max = 8.62 avg = 4.69 +``` + +### AMD Radeon RX 6900 XT of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 0 0 +[0 AMD Radeon RX 6900 XT] queueC=1[2] queueG=0[1] queueT=2[2] +[0 AMD Radeon RX 6900 XT] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 AMD Radeon RX 6900 XT] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 AMD Radeon RX 6900 XT] subgroup=64 basic/vote/ballot/shuffle=1/1/1/1 +[0 AMD Radeon RX 6900 XT] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 0 +cooling_down = 0 + squeezenet min = 
2.19 max = 2.70 avg = 2.47 + squeezenet_int8 min = 3.94 max = 4.51 avg = 4.18 + mobilenet min = 2.03 max = 2.63 avg = 2.28 + mobilenet_int8 min = 2.56 max = 3.34 avg = 2.69 + mobilenet_v2 min = 2.29 max = 2.98 avg = 2.62 + mobilenet_v3 min = 2.31 max = 3.10 avg = 2.75 + shufflenet min = 1.89 max = 2.61 avg = 2.30 + shufflenet_v2 min = 2.17 max = 3.04 avg = 2.59 + mnasnet min = 2.19 max = 2.98 avg = 2.69 + proxylessnasnet min = 2.12 max = 4.08 avg = 2.62 + efficientnet_b0 min = 3.62 max = 5.27 avg = 4.21 + efficientnetv2_b0 min = 6.09 max = 7.15 avg = 6.49 + regnety_400m min = 2.55 max = 3.82 avg = 3.00 + blazeface min = 1.93 max = 2.56 avg = 2.28 + googlenet min = 3.35 max = 4.46 avg = 3.75 + googlenet_int8 min = 8.02 max = 12.84 avg = 9.15 + resnet18 min = 2.46 max = 3.14 avg = 2.84 + resnet18_int8 min = 6.37 max = 9.15 avg = 7.30 + alexnet min = 2.31 max = 2.91 avg = 2.69 + vgg16 min = 4.76 max = 5.79 avg = 5.24 + vgg16_int8 min = 35.94 max = 46.27 avg = 39.05 + resnet50 min = 3.25 max = 4.09 avg = 3.75 + resnet50_int8 min = 12.04 max = 20.53 avg = 14.61 + squeezenet_ssd min = 3.03 max = 5.31 avg = 3.66 + squeezenet_ssd_int8 min = 9.74 max = 13.46 avg = 10.42 + mobilenet_ssd min = 2.82 max = 4.75 avg = 3.39 + mobilenet_ssd_int8 min = 4.67 max = 6.76 avg = 5.30 + mobilenet_yolo min = 3.01 max = 3.67 avg = 3.34 + mobilenetv2_yolov3 min = 4.04 max = 6.46 avg = 4.55 + yolov4-tiny min = 5.75 max = 8.05 avg = 6.52 + nanodet_m min = 10.16 max = 14.97 avg = 13.11 + yolo-fastest-1.1 min = 2.36 max = 3.80 avg = 2.88 + yolo-fastestv2 min = 2.24 max = 3.19 avg = 2.80 + vision_transformer min = 20.43 max = 25.06 avg = 21.07 + FastestDet min = 2.49 max = 3.18 avg = 2.93 +``` + +### NVIDIA GeForce RTX 3060 Ti of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 0 0 +[0 NVIDIA GeForce RTX 3060 Ti] queueC=2[8] queueG=0[16] queueT=1[2] +[0 NVIDIA GeForce RTX 3060 Ti] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +[1 Intel(R) UHD Graphics 770] queueC=0[1] queueG=0[1] queueT=0[1] +[1 Intel(R) UHD Graphics 770] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 Intel(R) UHD Graphics 770] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 Intel(R) UHD Graphics 770] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 Intel(R) UHD Graphics 770] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 0 +cooling_down = 0 + squeezenet min = 0.80 max = 2.51 avg = 0.89 + squeezenet_int8 min = 2.81 max = 3.51 avg = 2.96 + mobilenet min = 0.70 max = 0.79 avg = 0.71 + mobilenet_int8 min = 2.95 max = 3.44 avg = 3.03 + mobilenet_v2 min = 1.09 max = 1.25 avg = 1.12 + mobilenet_v3 min = 1.33 max = 2.04 avg = 1.56 + shufflenet min = 1.20 max = 1.39 avg = 1.27 + shufflenet_v2 min = 1.50 max = 1.66 avg = 1.57 + mnasnet min = 1.11 max = 1.22 avg = 1.15 + proxylessnasnet min = 1.20 max = 1.63 avg = 1.24 + efficientnet_b0 min = 2.38 max = 3.21 avg = 2.61 + efficientnetv2_b0 min = 9.16 max = 11.35 avg = 9.63 + regnety_400m min = 1.86 max = 2.03 avg = 1.94 + blazeface min = 0.70 max = 1.10 avg = 0.76 + googlenet min = 2.11 max = 2.40 avg = 2.30 + googlenet_int8 min = 6.91 max = 7.88 avg = 7.17 + resnet18 min = 1.14 max = 1.47 avg = 1.19 + resnet18_int8 min = 4.96 max = 6.82 avg = 5.40 + alexnet min = 1.10 max = 1.85 avg = 1.19 + vgg16 min = 2.27 
max = 3.97 avg = 2.46 + vgg16_int8 min = 19.02 max = 22.20 avg = 20.28 + resnet50 min = 2.00 max = 2.99 avg = 2.10 + resnet50_int8 min = 10.66 max = 13.30 avg = 11.29 + squeezenet_ssd min = 2.74 max = 3.44 avg = 2.90 + squeezenet_ssd_int8 min = 6.93 max = 7.95 avg = 7.19 + mobilenet_ssd min = 1.86 max = 2.07 avg = 1.96 + mobilenet_ssd_int8 min = 5.92 max = 6.48 avg = 6.09 + mobilenet_yolo min = 1.65 max = 2.58 avg = 1.78 + mobilenetv2_yolov3 min = 3.85 max = 4.11 avg = 3.96 + yolov4-tiny min = 6.54 max = 7.05 avg = 6.69 + nanodet_m min = 2.38 max = 3.28 avg = 2.72 + yolo-fastest-1.1 min = 1.73 max = 2.07 avg = 1.83 + yolo-fastestv2 min = 1.72 max = 1.92 avg = 1.80 + vision_transformer min = 53.91 max = 56.59 avg = 55.27 + FastestDet min = 1.48 max = 1.83 avg = 1.69 +``` + +### Intel(R) UHD Graphics 770 of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 1 0 +[0 NVIDIA GeForce RTX 3060 Ti] queueC=2[8] queueG=0[16] queueT=1[2] +[0 NVIDIA GeForce RTX 3060 Ti] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +[1 Intel(R) UHD Graphics 770] queueC=0[1] queueG=0[1] queueT=0[1] +[1 Intel(R) UHD Graphics 770] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 Intel(R) UHD Graphics 770] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 Intel(R) UHD Graphics 770] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 Intel(R) UHD Graphics 770] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 1 +cooling_down = 0 + squeezenet min = 3.11 max = 4.47 avg = 3.45 + squeezenet_int8 min = 1.89 max = 2.84 avg = 2.23 + mobilenet min = 4.98 max = 5.67 avg = 5.18 + mobilenet_int8 min = 2.54 max = 3.17 avg = 2.98 + mobilenet_v2 min = 4.03 max = 4.89 avg = 4.37 + mobilenet_v3 min = 4.45 max = 5.68 avg = 4.86 + shufflenet min = 3.42 max = 4.42 avg = 3.79 + shufflenet_v2 min = 3.00 max = 4.01 avg = 3.30 + mnasnet min = 4.21 max = 5.12 avg = 4.51 + proxylessnasnet min = 4.62 max = 5.64 avg = 4.90 + efficientnet_b0 min = 7.82 max = 8.63 avg = 8.10 + efficientnetv2_b0 min = 34.52 max = 36.34 avg = 35.29 + regnety_400m min = 6.07 max = 7.31 avg = 6.44 + blazeface min = 1.54 max = 1.67 avg = 1.59 + googlenet min = 11.53 max = 12.64 avg = 11.89 + googlenet_int8 min = 13.71 max = 15.52 avg = 14.38 + resnet18 min = 10.75 max = 12.94 avg = 11.07 + resnet18_int8 min = 9.04 max = 11.05 avg = 9.53 + alexnet min = 13.64 max = 14.37 avg = 13.98 + vgg16 min = 38.53 max = 40.16 avg = 39.22 + vgg16_int8 min = 16.04 max = 21.16 avg = 19.35 + resnet50 min = 25.61 max = 28.22 avg = 26.62 + resnet50_int8 min = 7.72 max = 12.83 avg = 10.29 + squeezenet_ssd min = 10.34 max = 15.88 avg = 14.75 + squeezenet_ssd_int8 min = 4.63 max = 7.13 avg = 5.66 + mobilenet_ssd min = 11.35 max = 13.06 avg = 12.44 + mobilenet_ssd_int8 min = 4.21 max = 6.31 avg = 5.32 + mobilenet_yolo min = 20.14 max = 22.92 avg = 21.94 + mobilenetv2_yolov3 min = 12.58 max = 14.88 avg = 14.21 + yolov4-tiny min = 20.62 max = 25.58 avg = 24.39 + nanodet_m min = 7.75 max = 12.49 avg = 11.42 + yolo-fastest-1.1 min = 3.68 max = 6.49 avg = 5.54 + yolo-fastestv2 min = 4.32 max = 5.39 avg = 4.51 + vision_transformer min = 796.51 max = 805.29 avg = 802.39 + FastestDet min = 2.89 max = 4.83 avg = 3.95 +``` + +### Intel® Core™ i7-13700K of Desktop[2023-10-12] +``` 
+E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 -1 0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 1.69 max = 2.63 avg = 2.12 + squeezenet_int8 min = 1.83 max = 3.03 avg = 2.26 + mobilenet min = 1.69 max = 2.64 avg = 2.24 + mobilenet_int8 min = 2.47 max = 3.06 avg = 2.84 + mobilenet_v2 min = 1.94 max = 3.47 avg = 2.47 + mobilenet_v3 min = 1.49 max = 2.74 avg = 1.87 + shufflenet min = 1.57 max = 3.00 avg = 1.82 + shufflenet_v2 min = 1.41 max = 1.72 avg = 1.51 + mnasnet min = 1.73 max = 2.94 avg = 2.13 + proxylessnasnet min = 2.08 max = 3.31 avg = 2.69 + efficientnet_b0 min = 3.20 max = 4.99 avg = 3.78 + efficientnetv2_b0 min = 3.51 max = 5.16 avg = 4.08 + regnety_400m min = 4.51 max = 10.29 avg = 6.18 + blazeface min = 0.52 max = 0.92 avg = 0.59 + googlenet min = 5.49 max = 7.48 avg = 6.26 + googlenet_int8 min = 4.83 max = 7.54 avg = 5.90 + resnet18 min = 4.05 max = 6.61 avg = 4.83 + resnet18_int8 min = 3.77 max = 5.70 avg = 4.57 + alexnet min = 3.60 max = 5.09 avg = 4.26 + vgg16 min = 25.19 max = 28.79 avg = 26.81 + vgg16_int8 min = 17.52 max = 21.79 avg = 19.80 + resnet50 min = 9.23 max = 13.15 avg = 11.34 + resnet50_int8 min = 7.77 max = 12.00 avg = 10.18 + squeezenet_ssd min = 4.33 max = 6.73 avg = 4.96 + squeezenet_ssd_int8 min = 4.77 max = 7.62 avg = 5.71 + mobilenet_ssd min = 3.70 max = 6.43 avg = 4.53 + mobilenet_ssd_int8 min = 4.16 max = 6.53 avg = 5.38 + mobilenet_yolo min = 11.27 max = 14.93 avg = 12.90 + mobilenetv2_yolov3 min = 7.41 max = 11.52 avg = 9.11 + yolov4-tiny min = 12.05 max = 18.96 avg = 14.15 + nanodet_m min = 3.39 max = 5.77 avg = 4.07 + yolo-fastest-1.1 min = 1.95 max = 3.85 avg = 2.30 + yolo-fastestv2 min = 1.91 max = 3.52 avg = 2.27 + vision_transformer min = 79.50 max = 99.93 avg = 88.91 + FastestDet min = 1.92 max = 2.72 avg = 2.19 +``` \ No newline at end of file diff --git a/benchmark/benchncnn.cpp b/benchmark/benchncnn.cpp index 3155396fbaa..df2e8d37b94 100644 --- a/benchmark/benchncnn.cpp +++ b/benchmark/benchncnn.cpp @@ -25,7 +25,10 @@ #include "datareader.h" #include "net.h" #include "gpu.h" + +#ifndef NCNN_SIMPLESTL #include +#endif class DataReaderFromEmpty : public ncnn::DataReader { diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index f1fefc4a9ea..b2481e3d87a 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -915,6 +915,7 @@ This function is often used in conjunction with affine_grid() to build Spatial T | 0 | sample_type | int | 1 | | | 1 | padding_mode | int | 1 | | | 2 | align_corner | int | 0 | | +| 3 | permute_fusion| int | 0 | fuse with permute | Sample type: diff --git a/docs/how-to-build/how-to-build.md b/docs/how-to-build/how-to-build.md index 731895d995b..a11735a6a5d 100644 --- a/docs/how-to-build/how-to-build.md +++ b/docs/how-to-build/how-to-build.md @@ -601,39 +601,11 @@ Pick `build-XYZ/install` folder for further usage. ### Build for AllWinner D1 -Download c906 toolchain package from https://occ.t-head.cn/community/download?id=4046947553902661632 +Download c906 toolchain package from https://xuantie.t-head.cn/community/download?id=4224193099938729984 ```shell -tar -xf Xuantie-900-gcc-linux-5.10.4-glibc-x86_64-V2.2.6-20220516.tar.gz -export RISCV_ROOT_PATH=/home/nihui/osd/Xuantie-900-gcc-linux-5.10.4-glibc-x86_64-V2.2.6 -``` - -You need to fix riscv_vector.h header for workaround vfrec7/vfrsqrt7 bug. 
- -Open ```$RISCV_ROOT_PATH/lib/gcc/riscv64-unknown-linux-gnu/10.2.0/include/riscv_vector.h```, goto the file end, you will find three ```#endif```, and apply changes as the following -```c -#endif - -#define vfrec7_v_f32m1(x, vl) vfrdiv_vf_f32m1(x, 1.f, vl) -#define vfrec7_v_f32m2(x, vl) vfrdiv_vf_f32m2(x, 1.f, vl) -#define vfrec7_v_f32m4(x, vl) vfrdiv_vf_f32m4(x, 1.f, vl) -#define vfrec7_v_f32m8(x, vl) vfrdiv_vf_f32m8(x, 1.f, vl) -#define vfrec7_v_f16m1(x, vl) vfrdiv_vf_f16m1(x, 1.f, vl) -#define vfrec7_v_f16m2(x, vl) vfrdiv_vf_f16m2(x, 1.f, vl) -#define vfrec7_v_f16m4(x, vl) vfrdiv_vf_f16m4(x, 1.f, vl) -#define vfrec7_v_f16m8(x, vl) vfrdiv_vf_f16m8(x, 1.f, vl) - -#define vfrsqrt7_v_f32m1(x, vl) vfrdiv_vf_f32m1(vfsqrt_v_f32m1(x, vl), 1.f, vl) -#define vfrsqrt7_v_f32m2(x, vl) vfrdiv_vf_f32m2(vfsqrt_v_f32m2(x, vl), 1.f, vl) -#define vfrsqrt7_v_f32m4(x, vl) vfrdiv_vf_f32m4(vfsqrt_v_f32m4(x, vl), 1.f, vl) -#define vfrsqrt7_v_f32m8(x, vl) vfrdiv_vf_f32m8(vfsqrt_v_f32m8(x, vl), 1.f, vl) -#define vfrsqrt7_v_f16m1(x, vl) vfrdiv_vf_f16m1(vfsqrt_v_f16m1(x, vl), 1.f, vl) -#define vfrsqrt7_v_f16m2(x, vl) vfrdiv_vf_f16m2(vfsqrt_v_f16m2(x, vl), 1.f, vl) -#define vfrsqrt7_v_f16m4(x, vl) vfrdiv_vf_f16m4(vfsqrt_v_f16m4(x, vl), 1.f, vl) -#define vfrsqrt7_v_f16m8(x, vl) vfrdiv_vf_f16m8(vfsqrt_v_f16m8(x, vl), 1.f, vl) - -#endif -#endif +tar -xf Xuantie-900-gcc-linux-5.10.4-glibc-x86_64-V2.6.1-20220906.tar.gz +export RISCV_ROOT_PATH=/home/nihui/osd/Xuantie-900-gcc-linux-5.10.4-glibc-x86_64-V2.6.1 ``` Build ncnn with riscv-v vector and simpleocv enabled: diff --git a/python/README.md b/python/README.md index 7d6a5d2bb79..1424da4cf6a 100644 --- a/python/README.md +++ b/python/README.md @@ -33,7 +33,41 @@ If you want to build ncnn with some options not as default, or just like to buil * Visual Studio 2015 or higher * CMake >= 3.4 -## Build +## Build & Install + +1. clone ncnn and init submodule. + +```bash +cd /pathto/ncnn +git submodule init && git submodule update +``` + +2. build and install. + +``` +python setup.py install +``` + +if you want to enable the usage of vulkan, you can install as following: + +``` +python setup.py install --vulkan=on +``` + +> **Attention:** +> +> To enable Vulkan support, you must first install the Vulkan SDK. +> +> **For Windows or Linux Users:** +> +> Ensure that the `VULKAN_SDK` environment variable is set to the path of the Vulkan SDK. +> +> **For MacOS Users:** +> +> On MacOS, you will need to specify additional environment variables. For guidance on setting these variables, please refer to lines 279-286 in the following file: [ncnn/.github/workflows/release-python.yml at master · Tencent/ncnn](https://github.com/Tencent/ncnn/blob/master/.github/workflows/release-python.yml). + +## Custom-build & Install + 1. clone ncnn and init submodule. ```bash cd /pathto/ncnn @@ -47,7 +81,8 @@ cmake -DNCNN_PYTHON=ON .. make ``` -## Install +3. install + ```bash cd /pathto/ncnn pip install . 
@@ -60,6 +95,7 @@ python3 setup.py install ``` ## Tests + **test** ```bash cd /pathto/ncnn/python diff --git a/python/src/main.cpp b/python/src/main.cpp index 983c0aa2654..8fe1cbf82b3 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -185,6 +185,9 @@ PYBIND11_MODULE(ncnn, m) #endif // NCNN_VULKAN .def_readwrite("openmp_blocktime", &Option::openmp_blocktime) .def_readwrite("use_winograd_convolution", &Option::use_winograd_convolution) + .def_readwrite("use_winograd23_convolution", &Option::use_winograd23_convolution) + .def_readwrite("use_winograd43_convolution", &Option::use_winograd43_convolution) + .def_readwrite("use_winograd63_convolution", &Option::use_winograd63_convolution) .def_readwrite("use_sgemm_convolution", &Option::use_sgemm_convolution) .def_readwrite("use_int8_inference", &Option::use_int8_inference) .def_readwrite("use_vulkan_compute", &Option::use_vulkan_compute) diff --git a/setup.py b/setup.py index 2bde78bc378..92b62453f24 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ from setuptools import setup, find_packages, Extension from setuptools.command.build_ext import build_ext +from setuptools.command.install import install def find_version(): @@ -22,6 +23,41 @@ def find_version(): return version_major[0] + "." + version_minor[0] + "." + ncnn_version raise RuntimeError("Unable to find version string.") +# Parse environment variables +NCNN_VULKAN = os.environ.get("NCNN_VULKAN", "") +Vulkan_INCLUDE_DIR = os.environ.get("Vulkan_INCLUDE_DIR", "") +Vulkan_LIBRARY = os.environ.get("Vulkan_LIBRARY", "") +VULKAN_SDK = os.environ.get("VULKAN_SDK", "") +CMAKE_TOOLCHAIN_FILE = os.environ.get("CMAKE_TOOLCHAIN_FILE", "") +PLATFORM = os.environ.get("PLATFORM", "") +ARCHS = os.environ.get("ARCHS", "") +DEPLOYMENT_TARGET = os.environ.get("DEPLOYMENT_TARGET", "") +OpenMP_C_FLAGS = os.environ.get("OpenMP_C_FLAGS", "") +OpenMP_CXX_FLAGS = os.environ.get("OpenMP_CXX_FLAGS", "") +OpenMP_C_LIB_NAMES = os.environ.get("OpenMP_C_LIB_NAMES", "") +OpenMP_CXX_LIB_NAMES = os.environ.get("OpenMP_CXX_LIB_NAMES", "") +OpenMP_libomp_LIBRARY = os.environ.get("OpenMP_libomp_LIBRARY", "") +ENABLE_BITCODE = os.environ.get("ENABLE_BITCODE", "") +ENABLE_ARC = os.environ.get("ENABLE_ARC", "") +ENABLE_VISIBILITY = os.environ.get("ENABLE_VISIBILITY", "") + +# Parse variables from command line with setup.py install +class InstallCommand(install): + user_options = install.user_options + [ + ('vulkan=', None, 'Enable the usage of Vulkan.'), + ] + def initialize_options(self): + install.initialize_options(self) + self.vulkan = None + + def finalize_options(self): + install.finalize_options(self) + + def run(self): + global NCNN_VULKAN + if self.vulkan == 'on' or self.vulkan == "ON": + NCNN_VULKAN = "ON" + install.run(self) # Convert distutils Windows platform specifiers to CMake -A arguments PLAT_TO_CMAKE = { @@ -70,6 +106,39 @@ def build_extension(self, ext): "-DNCNN_BUILD_EXAMPLES=OFF", "-DNCNN_BUILD_TOOLS=OFF", ] + if NCNN_VULKAN != "": + cmake_args.append("-DNCNN_VULKAN=" + NCNN_VULKAN) + if Vulkan_INCLUDE_DIR != "": + cmake_args.append("-DVulkan_INCLUDE_DIR=" + Vulkan_INCLUDE_DIR) + if Vulkan_LIBRARY != "": + cmake_args.append("-DVulkan_LIBRARY=" + Vulkan_LIBRARY) + if VULKAN_SDK != "": + cmake_args.append("-DVULKAN_SDK=" + VULKAN_SDK) + if CMAKE_TOOLCHAIN_FILE != "": + cmake_args.append("-DCMAKE_TOOLCHAIN_FILE=" + CMAKE_TOOLCHAIN_FILE) + if PLATFORM != "": + cmake_args.append("-DPLATFORM=" + PLATFORM) + if ARCHS != "": + cmake_args.append("-DARCHS=" + ARCHS) + if DEPLOYMENT_TARGET != "": + 
cmake_args.append("-DDEPLOYMENT_TARGET=" + DEPLOYMENT_TARGET) + if OpenMP_C_FLAGS != "": + cmake_args.append("-DOpenMP_C_FLAGS=" + OpenMP_C_FLAGS) + if OpenMP_CXX_FLAGS != "": + cmake_args.append("-DOpenMP_CXX_FLAGS=" + OpenMP_CXX_FLAGS) + if OpenMP_C_LIB_NAMES != "": + cmake_args.append("-DOpenMP_C_LIB_NAMES=" + OpenMP_C_LIB_NAMES) + if OpenMP_CXX_LIB_NAMES != "": + cmake_args.append("-DOpenMP_CXX_LIB_NAMES=" + OpenMP_CXX_LIB_NAMES) + if OpenMP_libomp_LIBRARY != "": + cmake_args.append("-DOpenMP_libomp_LIBRARY=" + OpenMP_libomp_LIBRARY) + if ENABLE_BITCODE != "": + cmake_args.append("-DENABLE_BITCODE=" + ENABLE_BITCODE) + if ENABLE_ARC != "": + cmake_args.append("-DENABLE_ARC=" + ENABLE_ARC) + if ENABLE_VISIBILITY != "": + cmake_args.append("-DENABLE_VISIBILITY=" + ENABLE_VISIBILITY) + build_args = [] if self.compiler.compiler_type == "msvc": @@ -150,5 +219,5 @@ def build_extension(self, ext): package_dir={"": "python"}, install_requires=requirements, ext_modules=[CMakeExtension("ncnn")], - cmdclass={"build_ext": CMakeBuild}, + cmdclass={'install': InstallCommand, "build_ext": CMakeBuild}, ) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c67426ff210..a664e2da929 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -39,6 +39,7 @@ set(ncnn_SRCS simpleocv.cpp simpleomp.cpp simplestl.cpp + simplemath.cpp ) if(ANDROID) @@ -208,7 +209,7 @@ if(NOT NCNN_SHARED_LIB) set_target_properties(ncnn PROPERTIES COMPILE_FLAGS -DNCNN_STATIC_DEFINE) endif() -if(NCNN_SIMPLESTL) +if(NCNN_SIMPLESTL AND NOT NCNN_SIMPLEMATH) # link math lib explicitly target_link_libraries(ncnn PUBLIC m) endif() @@ -261,7 +262,6 @@ if(NCNN_THREADS) if(TARGET Threads::Threads) target_link_libraries(ncnn PUBLIC Threads::Threads) endif() - if(NCNN_SIMPLEOMP OR NCNN_SIMPLESTL) target_link_libraries(ncnn PUBLIC pthread) endif() @@ -581,6 +581,7 @@ if(NCNN_INSTALL_SDK) simpleocv.h simpleomp.h simplestl.h + simplemath.h vulkan_header_fix.h ${CMAKE_CURRENT_BINARY_DIR}/ncnn_export.h ${CMAKE_CURRENT_BINARY_DIR}/layer_shader_type_enum.h @@ -599,5 +600,4 @@ endif() # add ncnn and generate-spirv to a virtual project group set_property(GLOBAL PROPERTY USE_FOLDERS ON) set_property(TARGET ncnn PROPERTY FOLDER "libncnn") -set_property(TARGET ncnn-generate-spirv PROPERTY FOLDER "libncnn") - +set_property(TARGET ncnn-generate-spirv PROPERTY FOLDER "libncnn") \ No newline at end of file diff --git a/src/gpu.cpp b/src/gpu.cpp index f32f6e20a67..72ca65bc620 100644 --- a/src/gpu.cpp +++ b/src/gpu.cpp @@ -16,7 +16,6 @@ #if NCNN_VULKAN -#include #include #include diff --git a/src/layer.cpp b/src/layer.cpp index a4f73a5c082..562576a5493 100644 --- a/src/layer.cpp +++ b/src/layer.cpp @@ -16,7 +16,6 @@ #include "cpu.h" -#include #include #ifdef _MSC_VER diff --git a/src/layer.h b/src/layer.h index ae4a8430d84..f0418a9ffcd 100644 --- a/src/layer.h +++ b/src/layer.h @@ -21,8 +21,6 @@ #include "paramdict.h" #include "platform.h" -#include - #if NCNN_VULKAN #include "command.h" #include "pipeline.h" diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index 25bfeb55557..55fb165911e 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -14,8 +14,6 @@ #include "binaryop_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/binaryop_arm_asimdhp.cpp b/src/layer/arm/binaryop_arm_asimdhp.cpp index 9d4e9b94f7c..b9a8ea2d00b 100644 --- a/src/layer/arm/binaryop_arm_asimdhp.cpp +++ b/src/layer/arm/binaryop_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include 
"binaryop_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/cast_arm_bf16.cpp b/src/layer/arm/cast_arm_bf16.cpp index aaaec09f968..358b9a9d2af 100644 --- a/src/layer/arm/cast_arm_bf16.cpp +++ b/src/layer/arm/cast_arm_bf16.cpp @@ -14,7 +14,7 @@ #include "cpu.h" #include "mat.h" -#include + namespace ncnn { #include "cast_bf16.h" diff --git a/src/layer/arm/convolution_3x3_int8.h b/src/layer/arm/convolution_3x3_int8.h index 826ed8a82e0..1868b5d6855 100644 --- a/src/layer/arm/convolution_3x3_int8.h +++ b/src/layer/arm/convolution_3x3_int8.h @@ -12,235 +12,6 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_winograd43_transform_kernel_int8_neon(const Mat& kernel, Mat& kernel_tm_packed, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 8a-8b-inch/8a-36-outch/8b -#if __ARM_NEON - if (outch >= 8) - { - kernel_tm_packed.create(inch, 36, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)2u * 8, 8); - } - else if (outch >= 4) - { - kernel_tm_packed.create(inch, 36, outch / 4 + outch % 4, (size_t)2u * 4, 4); - } -#else // __ARM_NEON - if (outch >= 2) - { - kernel_tm_packed.create(inch, 36, outch / 2 + outch % 2, (size_t)2u * 2, 2); - } -#endif // __ARM_NEON - else - { - kernel_tm_packed.create(inch, 36, outch, (size_t)2u, 1); - } - - int p = 0; -#if __ARM_NEON - for (; p + 7 < outch; p += 8) - { - Mat g0 = kernel_tm_packed.channel(p / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - for (int i = 0; i < 8; i++) - { - g00[0] = kernel_tm.channel(p + i).row(q)[k]; - g00++; - } - } - } - } - for (; p + 3 < outch; p += 4) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - const Mat k2 = kernel_tm.channel(p + 2); - const Mat k3 = kernel_tm.channel(p + 3); - - Mat g0 = kernel_tm_packed.channel(p / 8 + (p % 8) / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00[1] = k1.row(q)[k]; - g00[2] = k2.row(q)[k]; - g00[3] = k3.row(q)[k]; - g00 += 4; - } - } - } -#else // __ARM_NEON - for (; p + 1 < outch; p += 2) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - - Mat g0 = kernel_tm_packed.channel(p / 
2); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - int q = 0; -#if __ARM_FEATURE_SIMD32 - for (; q + 1 < inch; q += 2) - { - g00[0] = k0.row(q)[k]; - g00[2] = k1.row(q)[k]; - g00[1] = k0.row(q + 1)[k]; - g00[3] = k1.row(q + 1)[k]; - g00 += 4; - } -#endif // __ARM_FEATURE_SIMD32 - for (; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00[1] = k1.row(q)[k]; - g00 += 2; - } - } - } -#endif // __ARM_NEON - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - -#if __ARM_NEON - Mat g0 = kernel_tm_packed.channel(p / 8 + (p % 8) / 4 + p % 4); -#else - Mat g0 = kernel_tm_packed.channel(p / 2 + p % 2); -#endif - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00 += 1; - } - } - } -} - -static void conv3x3s1_winograd43_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - conv3x3s1_winograd43_transform_input_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - static void conv3x3s2_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch) { kernel_tm.create(8 * 9, inch, outch / 8 + outch % 8, (size_t)1u); diff --git a/src/layer/arm/convolution_3x3_pack8to1_int8.h b/src/layer/arm/convolution_3x3_pack8to1_int8.h deleted file mode 100644 index 5af9f5938e1..00000000000 --- a/src/layer/arm/convolution_3x3_pack8to1_int8.h +++ /dev/null @@ -1,185 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations under the License. - -static void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 8a-inch/8a-36-outch - kernel_tm_pack8to1.create(8 * inch / 8, 36, outch / 8 + outch % 8, (size_t)2u * 8, 8); - - int p = 0; - for (; p + 7 < outch; p += 8) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - const Mat k2 = kernel_tm.channel(p + 2); - const Mat k3 = kernel_tm.channel(p + 3); - const Mat k4 = kernel_tm.channel(p + 4); - const Mat k5 = kernel_tm.channel(p + 5); - const Mat k6 = kernel_tm.channel(p + 6); - const Mat k7 = kernel_tm.channel(p + 7); - - Mat g0 = kernel_tm_pack8to1.channel(p / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - g00[1] = k1.row(q + i)[k]; - g00[2] = k2.row(q + i)[k]; - g00[3] = k3.row(q + i)[k]; - g00[4] = k4.row(q + i)[k]; - g00[5] = k5.row(q + i)[k]; - g00[6] = k6.row(q + i)[k]; - g00[7] = k7.row(q + i)[k]; - - g00 += 8; - } - } - } - } - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - - Mat g0 = kernel_tm_pack8to1.channel(p / 8 + p % 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - - g00 += 1; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to1_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * 
elempack, elempack, opt.workspace_allocator); - conv3x3s1_winograd43_transform_input_pack8_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_pack8to1_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/arm/convolution_3x3_pack8to4_int8.h b/src/layer/arm/convolution_3x3_pack8to4_int8.h deleted file mode 100644 index ee67ba61ef7..00000000000 --- a/src/layer/arm/convolution_3x3_pack8to4_int8.h +++ /dev/null @@ -1,205 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -static void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8.create(inch / 8, 36, outch / 8 + (outch % 8) / 4, (size_t)2u * 64, 64); - - int q = 0; - for (; q + 7 < outch; q += 8) - { - const Mat k0 = kernel_tm.channel(q); - const Mat k1 = kernel_tm.channel(q + 1); - const Mat k2 = kernel_tm.channel(q + 2); - const Mat k3 = kernel_tm.channel(q + 3); - const Mat k4 = kernel_tm.channel(q + 4); - const Mat k5 = kernel_tm.channel(q + 5); - const Mat k6 = kernel_tm.channel(q + 6); - const Mat k7 = kernel_tm.channel(q + 7); - - Mat kernel_tm = kernel_tm_pack8.channel(q / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = kernel_tm.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 8; i++) - { - const short* k00 = k0.row(p + i); - const short* k10 = k1.row(p + i); - const short* k20 = k2.row(p + i); - const short* k30 = k3.row(p + i); - const short* k40 = k4.row(p + i); - const short* k50 = k5.row(p + i); - const short* k60 = k6.row(p + i); - const short* k70 = k7.row(p + i); - - g00[0] = k00[k]; - g00[1] = k10[k]; - g00[2] = k20[k]; - g00[3] = k30[k]; - g00[4] = k40[k]; - g00[5] = k50[k]; - g00[6] = k60[k]; - g00[7] = k70[k]; - - g00 += 8; - } - } - } - } - for (; q + 3 < outch; q += 4) - { - const Mat k0 = kernel_tm.channel(q); - const Mat k1 = kernel_tm.channel(q + 1); - const Mat k2 = kernel_tm.channel(q + 2); - const Mat k3 = kernel_tm.channel(q + 3); - - Mat kernel_tm = kernel_tm_pack8.channel(q / 8 + (q % 8) / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = kernel_tm.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 8; i++) - { - const short* k00 = k0.row(p + i); - const short* k10 = k1.row(p + i); - const short* k20 = k2.row(p + i); - const short* k30 = k3.row(p + i); - - g00[0] = k00[k]; - g00[1] = k10[k]; - g00[2] = k20[k]; - g00[3] = k30[k]; - - g00 += 4; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // 
pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - conv3x3s1_winograd43_transform_input_pack8_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_pack8to4_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_pack4_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/arm/convolution_3x3_winograd_int8.h b/src/layer/arm/convolution_3x3_winograd_int8.h new file mode 100644 index 00000000000..ab108b3f089 --- /dev/null +++ b/src/layer/arm/convolution_3x3_winograd_int8.h @@ -0,0 +1,5719 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
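// Illustrative aside (a hypothetical sketch, not code from this patch): the new
// convolution_3x3_winograd_int8.h below drives Winograd-43 int8 convolution as a batched
// GEMM. pack_A_tile_int8 and transpose_pack_B_tile_int8 lay out the two int16 operands in
// the interleaved order the NEON / inline-asm tiles expect, and gemm_transB_packed_tile_int8
// accumulates C += A * B^T per Winograd batch position into int32 sums, zero-initializing
// the accumulators when k == 0 and reloading partial sums from outptr otherwise. A minimal
// scalar model of that accumulation, assuming plain row-major A[max_ii][max_kk] and
// B[max_jj][max_kk] rather than the packed layouts used by the real kernels, might look like:
//
//   static void gemm_transB_scalar_sketch(const short* A, const short* B, int* C,
//                                         int max_ii, int max_jj, int max_kk, int k)
//   {
//       for (int i = 0; i < max_ii; i++)
//       {
//           for (int j = 0; j < max_jj; j++)
//           {
//               // k == 0 starts a fresh tile; later k-slices accumulate on top of C
//               int sum = (k == 0) ? 0 : C[i * max_jj + j];
//               for (int kk = 0; kk < max_kk; kk++)
//                   sum += (int)A[i * max_kk + kk] * (int)B[j * max_kk + kk];
//               C[i * max_jj + j] = sum;
//           }
//       }
//   }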
+ +static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) +{ + const int N = max_kk * batch; + + for (int b = 0; b < batch; b++) + { + short* pp = AT.row(b); + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + p0 += batch; + pp += 8; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + p0 += batch; + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[batch + N]; + p0 += batch * 2; + pp += 4; + } +#endif + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + p0 += batch; + pp += 2; + } + } + for (; ii < max_ii; ii++) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += batch; + pp += 1; + } + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) +{ + // NCNN_LOGE("transpose_pack_B_tile_int8 %d %d", max_jj, max_kk); + + #pragma omp parallel for num_threads(nT) + for (int b = 0; b < batch; b++) + { + short* pp = BT.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose 8x12 +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "prfm pldl1keep, [%0, #1024] \n" + "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" + "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n" + "ld4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%0] \n" + "sub %0, %0, #128 \n" + "uzp1 v20.8h, v0.8h, v4.8h \n" + "uzp2 v26.8h, v0.8h, v4.8h \n" + "uzp1 v23.8h, v2.8h, v6.8h \n" + "uzp2 v29.8h, v2.8h, v6.8h \n" + "uzp1 v21.8h, v16.8h, v1.8h \n" + "uzp2 v27.8h, v16.8h, v1.8h \n" + "uzp1 v22.8h, v5.8h, v17.8h \n" + "uzp2 v28.8h, v5.8h, v17.8h \n" + "uzp1 v24.8h, v18.8h, v3.8h \n" + "uzp2 v30.8h, v18.8h, v3.8h \n" + "uzp1 v25.8h, v7.8h, v19.8h \n" + "uzp2 v31.8h, v7.8h, v19.8h \n" + "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [%1], #64 \n" + "st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + p0 += max_jj * batch * 8; +#else // NCNN_GNU_INLINE_ASM + int16x8x4_t _r0 = vld4q_s16(p0); + int16x8x4_t _r1 = vld4q_s16(p0 + 32); + int16x8x4_t _r2 = vld4q_s16(p0 + 64); + int16x8x2_t _t0 = vuzpq_s16(_r0.val[0], _r1.val[0]); + int16x8x2_t _t1 = vuzpq_s16(_r2.val[0], _r0.val[1]); + int16x8x2_t _t2 = vuzpq_s16(_r1.val[1], _r2.val[1]); + int16x8x2_t _t3 = vuzpq_s16(_r0.val[2], _r1.val[2]); + int16x8x2_t _t4 = vuzpq_s16(_r2.val[2], _r0.val[3]); + int16x8x2_t _t5 = vuzpq_s16(_r1.val[3], _r2.val[3]); + vst1q_s16(pp, 
_t0.val[0]); + vst1q_s16(pp + 8, _t1.val[0]); + vst1q_s16(pp + 16, _t2.val[0]); + vst1q_s16(pp + 24, _t3.val[0]); + vst1q_s16(pp + 32, _t4.val[0]); + vst1q_s16(pp + 40, _t5.val[0]); + vst1q_s16(pp + 48, _t0.val[1]); + vst1q_s16(pp + 56, _t1.val[1]); + vst1q_s16(pp + 64, _t2.val[1]); + vst1q_s16(pp + 72, _t3.val[1]); + vst1q_s16(pp + 80, _t4.val[1]); + vst1q_s16(pp + 88, _t5.val[1]); + p0 += max_jj * batch * 8; + pp += 96; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + int16x4x2_t _r2 = vld2_s16(p0 + 16); + vst1q_s16(pp, _r01.val[0]); + vst1_s16(pp + 8, _r2.val[0]); + vst1q_s16(pp + 12, _r01.val[1]); + vst1_s16(pp + 20, _r2.val[1]); + p0 += max_jj * batch * 2; + pp += 24; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x8_t _r0 = vld1q_s16(p0); + int16x4_t _r1 = vld1_s16(p0 + 8); + vst1q_s16(pp, _r0); + vst1_s16(pp + 8, _r1); + p0 += max_jj * batch; + pp += 12; + } + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose 8x8 +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "prfm pldl1keep, [%0, #1024] \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" + "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" + "sub %0, %0, #64 \n" + "zip1 v16.8h, v0.8h, v4.8h \n" + "zip2 v20.8h, v0.8h, v4.8h \n" + "zip1 v17.8h, v1.8h, v5.8h \n" + "zip2 v21.8h, v1.8h, v5.8h \n" + "zip1 v18.8h, v2.8h, v6.8h \n" + "zip2 v22.8h, v2.8h, v6.8h \n" + "zip1 v19.8h, v3.8h, v7.8h \n" + "zip2 v23.8h, v3.8h, v7.8h \n" + "st4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" + "st4 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + p0 += max_jj * batch * 8; +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + int16x8_t _r1 = vld1q_s16(p0 + 8); + int16x8_t _r2 = vld1q_s16(p0 + 16); + int16x8_t _r3 = vld1q_s16(p0 + 24); + int16x8_t _r4 = vld1q_s16(p0 + 32); + int16x8_t _r5 = vld1q_s16(p0 + 40); + int16x8_t _r6 = vld1q_s16(p0 + 48); + int16x8_t _r7 = vld1q_s16(p0 + 56); + int16x8x2_t _r04 = vzipq_s16(_r0, _r4); + int16x8x2_t _r15 = vzipq_s16(_r1, _r5); + int16x8x2_t _r26 = vzipq_s16(_r2, _r6); + int16x8x2_t _r37 = vzipq_s16(_r3, _r7); + int16x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int16x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4q_s16(pp, _r0123); + vst4q_s16(pp + 32, _r4567); + p0 += max_jj * batch * 8; + pp += 64; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + vst1q_s16(pp, _r01.val[0]); + vst1q_s16(pp + 8, _r01.val[1]); + p0 += max_jj * batch * 2; + pp += 16; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x8_t _r0 = vld1q_s16(p0); + vst1q_s16(pp, _r0); + p0 += max_jj * batch; + pp += 8; + } + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 
8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #768] \n" + "ld1 {v0.8h, v1.8h, v2.8h}, [%0], #48 \n" + "ld1 {v3.8h, v4.8h, v5.8h}, [%0] \n" + "sub %0, %0, #48 \n" + "zip1 v16.8h, v0.8h, v3.8h \n" + "zip2 v20.8h, v0.8h, v3.8h \n" + "zip1 v17.8h, v1.8h, v4.8h \n" + "zip2 v21.8h, v1.8h, v4.8h \n" + "zip1 v18.8h, v2.8h, v5.8h \n" + "zip2 v22.8h, v2.8h, v5.8h \n" + "st3 {v16.8h, v17.8h, v18.8h}, [%1], #48 \n" + "st3 {v20.8h, v21.8h, v22.8h}, [%1], #48 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v20", "v21", "v22"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #768] \n" + "vldm %0, {d0-d11} \n" + "vzip.16 q0, q3 \n" + "vzip.16 q1, q4 \n" + "vzip.16 q2, q5 \n" + "vst3.s16 {d0,d2,d4}, [%1]! \n" + "vst3.s16 {d1,d3,d5}, [%1]! \n" + "vst3.s16 {d6,d8,d10}, [%1]! \n" + "vst3.s16 {d7,d9,d11}, [%1]! \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + int16x8_t _r1 = vld1q_s16(p0 + 8); + int16x8_t _r2 = vld1q_s16(p0 + 16); + int16x8_t _r3 = vld1q_s16(p0 + 24); + int16x8_t _r4 = vld1q_s16(p0 + 32); + int16x8_t _r5 = vld1q_s16(p0 + 40); + int16x8x2_t _r03 = vzipq_s16(_r0, _r3); + int16x8x2_t _r14 = vzipq_s16(_r1, _r4); + int16x8x2_t _r25 = vzipq_s16(_r2, _r5); + int16x8x3_t _r012; + _r012.val[0] = _r03.val[0]; + _r012.val[1] = _r14.val[0]; + _r012.val[2] = _r25.val[0]; + int16x8x3_t _r345; + _r345.val[0] = _r03.val[1]; + _r345.val[1] = _r14.val[1]; + _r345.val[2] = _r25.val[1]; + vst3q_s16(pp, _r012); + vst3q_s16(pp + 24, _r345); + p0 += max_jj * batch * 8; + pp += 48; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + int32x4x2_t _r01x = vtrnq_s32(vreinterpretq_s32_s16(_r01.val[0]), vreinterpretq_s32_s16(_r01.val[1])); + int32x2x3_t _r012; + _r012.val[0] = vget_low_s32(_r01x.val[0]); + _r012.val[1] = vget_low_s32(_r01x.val[1]); + _r012.val[2] = vget_high_s32(_r01x.val[0]); + vst3_s32((int*)pp, _r012); + p0 += max_jj * batch * 2; + pp += 12; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x4_t _r0 = vld1_s16(p0); + vst1_s16(pp, _r0); + pp[4] = p0[4]; + pp[5] = p0[5]; + p0 += max_jj * batch; + pp += 6; + } + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" + "st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #512] \n" + "vldm %0, {d0-d7} \n" + "vst4.s16 {d0,d2,d4,d6}, [%1]! \n" + "vst4.s16 {d1,d3,d5,d7}, [%1]! 
\n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1", "q2", "q3"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8x4_t _r0123; + _r0123.val[0] = vld1q_s16(p0); + _r0123.val[1] = vld1q_s16(p0 + 8); + _r0123.val[2] = vld1q_s16(p0 + 16); + _r0123.val[3] = vld1q_s16(p0 + 24); + vst4q_s16(pp, _r0123); + p0 += max_jj * batch * 8; + pp += 32; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x4x2_t _r01 = vld2_s16(p0); + vst1_s16(pp, _r01.val[0]); + vst1_s16(pp + 4, _r01.val[1]); + p0 += max_jj * batch * 2; + pp += 8; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x4_t _r0 = vld1_s16(p0); + vst1_s16(pp, _r0); + p0 += max_jj * batch; + pp += 4; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* p0 = B; + + int kk = 0; +#if __ARM_NEON + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #256] \n" + "ld1 {v0.8h, v1.8h}, [%0] \n" + "st2 {v0.8h, v1.8h}, [%1], #32 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #256] \n" + "vld1.s16 {d0-d3}, [%0] \n" + "vst2.s16 {d0-d3}, [%1]! \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8x2_t _r01; + _r01.val[0] = vld1q_s16(p0); + _r01.val[1] = vld1q_s16(p0 + 8); + vst2q_s16(pp, _r01); + p0 += max_jj * batch * 8; + pp += 16; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; +#endif // __ARM_NEON + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; +#else + pp[0] = p0[0]; + pp[1] = p0[2]; + pp[2] = p0[1]; + pp[3] = p0[3]; +#endif + p0 += max_jj * batch * 2; + pp += 4; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch; + pp += 2; + } + } + for (; jj < max_jj; jj++) + { + const short* p0 = B; + + int kk = 0; +#if __ARM_NEON + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v0.8h}, [%0] \n" + "st1 {v0.8h}, [%1], #16 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #128] \n" + "vld1.s16 {d0-d1}, [%0] \n" + "vst1.s16 {d0-d1}, [%1]! 
\n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + vst1q_s16(pp, _r0); + p0 += max_jj * batch * 8; + pp += 8; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; +#endif // __ARM_NEON + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch * 2; + pp += 2; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += max_jj * batch; + pp += 1; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk) +{ + // return; + // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d", max_ii, max_jj, max_kk); + + int* outptr = top_blob; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0], #64 \n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #320 \n" + "b 1f \n" + + "0: \n" + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v8.4s, v4.4h, v0.h[0] \n" + "smlal v10.4s, v4.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v9.4s, v4.8h, v0.h[0] \n" + "smlal2 v11.4s, v4.8h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal v12.4s, v4.4h, v0.h[2] \n" + "smlal v14.4s, v4.4h, v0.h[3] \n" + "smlal2 v13.4s, v4.8h, v0.h[2] \n" + "smlal2 v15.4s, v4.8h, v0.h[3] \n" + "smlal v16.4s, v4.4h, v0.h[4] \n" + "smlal v18.4s, v4.4h, v0.h[5] \n" + "smlal2 v17.4s, v4.8h, v0.h[4] \n" + "smlal2 v19.4s, v4.8h, v0.h[5] \n" + "smlal v20.4s, v4.4h, v0.h[6] \n" + "smlal v22.4s, v4.4h, v0.h[7] \n" + "smlal2 v21.4s, v4.8h, v0.h[6] \n" + "smlal2 v23.4s, v4.8h, v0.h[7] \n" + "smlal v24.4s, v4.4h, v1.h[0] \n" + "smlal v26.4s, v4.4h, v1.h[1] \n" + "smlal2 
v25.4s, v4.8h, v1.h[0] \n" + "smlal2 v27.4s, v4.8h, v1.h[1] \n" + "smlal v28.4s, v4.4h, v1.h[2] \n" + "smlal v30.4s, v4.4h, v1.h[3] \n" + "smlal2 v29.4s, v4.8h, v1.h[2] \n" + "smlal2 v31.4s, v4.8h, v1.h[3] \n" + "smlal v8.4s, v5.4h, v1.h[4] \n" + "smlal v10.4s, v5.4h, v1.h[5] \n" + "smlal2 v9.4s, v5.8h, v1.h[4] \n" + "smlal2 v11.4s, v5.8h, v1.h[5] \n" + "smlal v12.4s, v5.4h, v1.h[6] \n" + "smlal v14.4s, v5.4h, v1.h[7] \n" + "smlal2 v13.4s, v5.8h, v1.h[6] \n" + "smlal2 v15.4s, v5.8h, v1.h[7] \n" + "smlal v16.4s, v5.4h, v2.h[0] \n" + "smlal v18.4s, v5.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v17.4s, v5.8h, v2.h[0] \n" + "smlal2 v19.4s, v5.8h, v2.h[1] \n" + "smlal v20.4s, v5.4h, v2.h[2] \n" + "smlal v22.4s, v5.4h, v2.h[3] \n" + "smlal2 v21.4s, v5.8h, v2.h[2] \n" + "smlal2 v23.4s, v5.8h, v2.h[3] \n" + "smlal v24.4s, v5.4h, v2.h[4] \n" + "smlal v26.4s, v5.4h, v2.h[5] \n" + "smlal2 v25.4s, v5.8h, v2.h[4] \n" + "smlal2 v27.4s, v5.8h, v2.h[5] \n" + "smlal v28.4s, v5.4h, v2.h[6] \n" + "smlal v30.4s, v5.4h, v2.h[7] \n" + "smlal2 v29.4s, v5.8h, v2.h[6] \n" + "smlal2 v31.4s, v5.8h, v2.h[7] \n" + "smlal v8.4s, v6.4h, v3.h[0] \n" + "smlal v10.4s, v6.4h, v3.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v9.4s, v6.8h, v3.h[0] \n" + "smlal2 v11.4s, v6.8h, v3.h[1] \n" + "smlal v12.4s, v6.4h, v3.h[2] \n" + "smlal v14.4s, v6.4h, v3.h[3] \n" + "smlal2 v13.4s, v6.8h, v3.h[2] \n" + "smlal2 v15.4s, v6.8h, v3.h[3] \n" + "smlal v16.4s, v6.4h, v3.h[4] \n" + "smlal v18.4s, v6.4h, v3.h[5] \n" + "smlal2 v17.4s, v6.8h, v3.h[4] \n" + "smlal2 v19.4s, v6.8h, v3.h[5] \n" + "smlal v20.4s, v6.4h, v3.h[6] \n" + "smlal v22.4s, v6.4h, v3.h[7] \n" + "smlal2 v21.4s, v6.8h, v3.h[6] \n" + "smlal2 v23.4s, v6.8h, v3.h[7] \n" + "smlal v24.4s, v6.4h, v0.h[0] \n" + "smlal v26.4s, v6.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v25.4s, v6.8h, v0.h[0] \n" + "smlal2 v27.4s, v6.8h, v0.h[1] \n" + "smlal v28.4s, v6.4h, v0.h[2] \n" + "smlal v30.4s, v6.4h, v0.h[3] \n" + "smlal2 v29.4s, v6.8h, v0.h[2] \n" + "smlal2 v31.4s, v6.8h, v0.h[3] \n" + "smlal v8.4s, v7.4h, v0.h[4] \n" + "smlal v10.4s, v7.4h, v0.h[5] \n" + "smlal2 v9.4s, v7.8h, v0.h[4] \n" + "smlal2 v11.4s, v7.8h, v0.h[5] \n" + "smlal v12.4s, v7.4h, v0.h[6] \n" + "smlal v14.4s, v7.4h, v0.h[7] \n" + "smlal2 v13.4s, v7.8h, v0.h[6] \n" + "smlal2 v15.4s, v7.8h, v0.h[7] \n" + "smlal v16.4s, v7.4h, v1.h[0] \n" + "smlal v18.4s, v7.4h, v1.h[1] \n" + "smlal2 v17.4s, v7.8h, v1.h[0] \n" + "smlal2 v19.4s, v7.8h, v1.h[1] \n" + "smlal v20.4s, v7.4h, v1.h[2] \n" + "smlal v22.4s, v7.4h, v1.h[3] \n" + "smlal2 v21.4s, v7.8h, v1.h[2] \n" + "smlal2 v23.4s, v7.8h, v1.h[3] \n" + "smlal v24.4s, v7.4h, v1.h[4] \n" + "smlal v26.4s, v7.4h, v1.h[5] \n" + "smlal2 v25.4s, v7.8h, v1.h[4] \n" + "smlal2 v27.4s, v7.8h, v1.h[5] \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "smlal v8.4s, v4.4h, v2.h[0] \n" + "smlal v10.4s, v4.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v9.4s, v4.8h, v2.h[0] \n" + "smlal2 v11.4s, v4.8h, v2.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal v12.4s, v4.4h, v2.h[2] \n" + "smlal v14.4s, v4.4h, v2.h[3] \n" + "smlal2 v13.4s, v4.8h, v2.h[2] \n" + "smlal2 v15.4s, v4.8h, v2.h[3] \n" + "smlal v16.4s, v4.4h, v2.h[4] \n" + "smlal v18.4s, v4.4h, v2.h[5] \n" + "smlal2 v17.4s, v4.8h, v2.h[4] \n" + "smlal2 v19.4s, v4.8h, v2.h[5] \n" 
+ "smlal v20.4s, v4.4h, v2.h[6] \n" + "smlal v22.4s, v4.4h, v2.h[7] \n" + "smlal2 v21.4s, v4.8h, v2.h[6] \n" + "smlal2 v23.4s, v4.8h, v2.h[7] \n" + "smlal v24.4s, v4.4h, v3.h[0] \n" + "smlal v26.4s, v4.4h, v3.h[1] \n" + "smlal2 v25.4s, v4.8h, v3.h[0] \n" + "smlal2 v27.4s, v4.8h, v3.h[1] \n" + "smlal v28.4s, v4.4h, v3.h[2] \n" + "smlal v30.4s, v4.4h, v3.h[3] \n" + "smlal2 v29.4s, v4.8h, v3.h[2] \n" + "smlal2 v31.4s, v4.8h, v3.h[3] \n" + "smlal v8.4s, v5.4h, v3.h[4] \n" + "smlal v10.4s, v5.4h, v3.h[5] \n" + "smlal2 v9.4s, v5.8h, v3.h[4] \n" + "smlal2 v11.4s, v5.8h, v3.h[5] \n" + "smlal v12.4s, v5.4h, v3.h[6] \n" + "smlal v14.4s, v5.4h, v3.h[7] \n" + "smlal2 v13.4s, v5.8h, v3.h[6] \n" + "smlal2 v15.4s, v5.8h, v3.h[7] \n" + "smlal v16.4s, v5.4h, v0.h[0] \n" + "smlal v18.4s, v5.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v17.4s, v5.8h, v0.h[0] \n" + "smlal2 v19.4s, v5.8h, v0.h[1] \n" + "smlal v20.4s, v5.4h, v0.h[2] \n" + "smlal v22.4s, v5.4h, v0.h[3] \n" + "smlal2 v21.4s, v5.8h, v0.h[2] \n" + "smlal2 v23.4s, v5.8h, v0.h[3] \n" + "smlal v24.4s, v5.4h, v0.h[4] \n" + "smlal v26.4s, v5.4h, v0.h[5] \n" + "smlal2 v25.4s, v5.8h, v0.h[4] \n" + "smlal2 v27.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v5.4h, v0.h[6] \n" + "smlal v30.4s, v5.4h, v0.h[7] \n" + "smlal2 v29.4s, v5.8h, v0.h[6] \n" + "smlal2 v31.4s, v5.8h, v0.h[7] \n" + "smlal v8.4s, v6.4h, v1.h[0] \n" + "smlal v10.4s, v6.4h, v1.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v9.4s, v6.8h, v1.h[0] \n" + "smlal2 v11.4s, v6.8h, v1.h[1] \n" + "smlal v12.4s, v6.4h, v1.h[2] \n" + "smlal v14.4s, v6.4h, v1.h[3] \n" + "smlal2 v13.4s, v6.8h, v1.h[2] \n" + "smlal2 v15.4s, v6.8h, v1.h[3] \n" + "smlal v16.4s, v6.4h, v1.h[4] \n" + "smlal v18.4s, v6.4h, v1.h[5] \n" + "smlal2 v17.4s, v6.8h, v1.h[4] \n" + "smlal2 v19.4s, v6.8h, v1.h[5] \n" + "smlal v20.4s, v6.4h, v1.h[6] \n" + "smlal v22.4s, v6.4h, v1.h[7] \n" + "smlal2 v21.4s, v6.8h, v1.h[6] \n" + "smlal2 v23.4s, v6.8h, v1.h[7] \n" + "smlal v24.4s, v6.4h, v2.h[0] \n" + "smlal v26.4s, v6.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v25.4s, v6.8h, v2.h[0] \n" + "smlal2 v27.4s, v6.8h, v2.h[1] \n" + "smlal v28.4s, v6.4h, v2.h[2] \n" + "smlal v30.4s, v6.4h, v2.h[3] \n" + "smlal2 v29.4s, v6.8h, v2.h[2] \n" + "smlal2 v31.4s, v6.8h, v2.h[3] \n" + "smlal v8.4s, v7.4h, v2.h[4] \n" + "smlal v10.4s, v7.4h, v2.h[5] \n" + "smlal2 v9.4s, v7.8h, v2.h[4] \n" + "smlal2 v11.4s, v7.8h, v2.h[5] \n" + "smlal v12.4s, v7.4h, v2.h[6] \n" + "smlal v14.4s, v7.4h, v2.h[7] \n" + "smlal2 v13.4s, v7.8h, v2.h[6] \n" + "smlal2 v15.4s, v7.8h, v2.h[7] \n" + "smlal v16.4s, v7.4h, v3.h[0] \n" + "smlal v18.4s, v7.4h, v3.h[1] \n" + "smlal2 v17.4s, v7.8h, v3.h[0] \n" + "smlal2 v19.4s, v7.8h, v3.h[1] \n" + "smlal v20.4s, v7.4h, v3.h[2] \n" + "smlal v22.4s, v7.4h, v3.h[3] \n" + "smlal2 v21.4s, v7.8h, v3.h[2] \n" + "smlal2 v23.4s, v7.8h, v3.h[3] \n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h, v1.4h, v2.4h}, [%2], #24 \n" + "smlal v8.4s, v4.4h, v0.h[0] \n" + "smlal v10.4s, 
v4.4h, v0.h[1] \n" + "smlal2 v9.4s, v4.8h, v0.h[0] \n" + "smlal2 v11.4s, v4.8h, v0.h[1] \n" + "smlal v12.4s, v4.4h, v0.h[2] \n" + "smlal v14.4s, v4.4h, v0.h[3] \n" + "smlal2 v13.4s, v4.8h, v0.h[2] \n" + "smlal2 v15.4s, v4.8h, v0.h[3] \n" + "smlal v16.4s, v4.4h, v1.h[0] \n" + "smlal v18.4s, v4.4h, v1.h[1] \n" + "smlal2 v17.4s, v4.8h, v1.h[0] \n" + "smlal2 v19.4s, v4.8h, v1.h[1] \n" + "smlal v20.4s, v4.4h, v1.h[2] \n" + "smlal v22.4s, v4.4h, v1.h[3] \n" + "smlal2 v21.4s, v4.8h, v1.h[2] \n" + "smlal2 v23.4s, v4.8h, v1.h[3] \n" + "smlal v24.4s, v4.4h, v2.h[0] \n" + "smlal v26.4s, v4.4h, v2.h[1] \n" + "smlal2 v25.4s, v4.8h, v2.h[0] \n" + "smlal2 v27.4s, v4.8h, v2.h[1] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v4.4h, v2.h[2] \n" + "smlal v30.4s, v4.4h, v2.h[3] \n" + "smlal2 v29.4s, v4.8h, v2.h[2] \n" + "smlal2 v31.4s, v4.8h, v2.h[3] \n" + "bne 4b \n" + + "5: \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0], #64 \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + int32x4_t _sumg; + int32x4_t _sumh; + int32x4_t _sumi; + int32x4_t _sumj; + int32x4_t _sumk; + int32x4_t _suml; + int32x4_t _summ; + int32x4_t _sumn; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + _sumg = vdupq_n_s32(0); + _sumh = vdupq_n_s32(0); + _sumi = vdupq_n_s32(0); + _sumj = vdupq_n_s32(0); + _sumk = vdupq_n_s32(0); + _suml = vdupq_n_s32(0); + _summ = vdupq_n_s32(0); + _sumn = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + _sumg = vld1q_s32(outptr + 64); + _sumh = vld1q_s32(outptr + 68); + _sumi = vld1q_s32(outptr + 72); + _sumj = vld1q_s32(outptr + 76); + _sumk = vld1q_s32(outptr + 80); + _suml = vld1q_s32(outptr + 84); + _summ = vld1q_s32(outptr + 88); + _sumn = vld1q_s32(outptr + 92); + } + + int kk = 0; + 
for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_laneq_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_laneq_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_laneq_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_laneq_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_laneq_s16(_sum7, vget_high_s16(_pA), _pB, 3); + _sum8 = vmlal_laneq_s16(_sum8, vget_low_s16(_pA), _pB, 4); + _sum9 = vmlal_laneq_s16(_sum9, vget_high_s16(_pA), _pB, 4); + _suma = vmlal_laneq_s16(_suma, vget_low_s16(_pA), _pB, 5); + _sumb = vmlal_laneq_s16(_sumb, vget_high_s16(_pA), _pB, 5); + _sumc = vmlal_laneq_s16(_sumc, vget_low_s16(_pA), _pB, 6); + _sumd = vmlal_laneq_s16(_sumd, vget_high_s16(_pA), _pB, 6); + _sume = vmlal_laneq_s16(_sume, vget_low_s16(_pA), _pB, 7); + _sumf = vmlal_laneq_s16(_sumf, vget_high_s16(_pA), _pB, 7); + _sumg = vmlal_lane_s16(_sumg, vget_low_s16(_pA), _pB2, 0); + _sumh = vmlal_lane_s16(_sumh, vget_high_s16(_pA), _pB2, 0); + _sumi = vmlal_lane_s16(_sumi, vget_low_s16(_pA), _pB2, 1); + _sumj = vmlal_lane_s16(_sumj, vget_high_s16(_pA), _pB2, 1); + _sumk = vmlal_lane_s16(_sumk, vget_low_s16(_pA), _pB2, 2); + _suml = vmlal_lane_s16(_suml, vget_high_s16(_pA), _pB2, 2); + _summ = vmlal_lane_s16(_summ, vget_low_s16(_pA), _pB2, 3); + _sumn = vmlal_lane_s16(_sumn, vget_high_s16(_pA), _pB2, 3); + pA += 8; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + vst1q_s32(outptr + 64, _sumg); + vst1q_s32(outptr + 68, _sumh); + vst1q_s32(outptr + 72, _sumi); + vst1q_s32(outptr + 76, _sumj); + vst1q_s32(outptr + 80, _sumk); + vst1q_s32(outptr + 84, _suml); + vst1q_s32(outptr + 88, _summ); + vst1q_s32(outptr + 92, _sumn); + outptr += 96; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + 
"lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v16.4s, v4.4h, v0.h[0] \n" + "smlal v18.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v17.4s, v4.8h, v0.h[0] \n" + "smlal2 v19.4s, v4.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v20.4s, v4.4h, v0.h[2] \n" + "smlal v22.4s, v4.4h, v0.h[3] \n" + "smlal2 v21.4s, v4.8h, v0.h[2] \n" + "smlal2 v23.4s, v4.8h, v0.h[3] \n" + "smlal v24.4s, v4.4h, v0.h[4] \n" + "smlal v26.4s, v4.4h, v0.h[5] \n" + "smlal2 v25.4s, v4.8h, v0.h[4] \n" + "smlal2 v27.4s, v4.8h, v0.h[5] \n" + "smlal v28.4s, v4.4h, v0.h[6] \n" + "smlal v30.4s, v4.4h, v0.h[7] \n" + "smlal2 v29.4s, v4.8h, v0.h[6] \n" + "smlal2 v31.4s, v4.8h, v0.h[7] \n" + "smlal v16.4s, v5.4h, v1.h[0] \n" + "smlal v18.4s, v5.4h, v1.h[1] \n" + "smlal2 v17.4s, v5.8h, v1.h[0] \n" + "smlal2 v19.4s, v5.8h, v1.h[1] \n" + "smlal v20.4s, v5.4h, v1.h[2] \n" + "smlal v22.4s, v5.4h, v1.h[3] \n" + "smlal2 v21.4s, v5.8h, v1.h[2] \n" + "smlal2 v23.4s, v5.8h, v1.h[3] \n" + "smlal v24.4s, v5.4h, v1.h[4] \n" + "smlal v26.4s, v5.4h, v1.h[5] \n" + "smlal2 v25.4s, v5.8h, v1.h[4] \n" + "smlal2 v27.4s, v5.8h, v1.h[5] \n" + "smlal v28.4s, v5.4h, v1.h[6] \n" + "smlal v30.4s, v5.4h, v1.h[7] \n" + "smlal2 v29.4s, v5.8h, v1.h[6] \n" + "smlal2 v31.4s, v5.8h, v1.h[7] \n" + "smlal v16.4s, v6.4h, v2.h[0] \n" + "smlal v18.4s, v6.4h, v2.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v17.4s, v6.8h, v2.h[0] \n" + "smlal2 v19.4s, v6.8h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal v20.4s, v6.4h, v2.h[2] \n" + "smlal v22.4s, v6.4h, v2.h[3] \n" + "smlal2 v21.4s, v6.8h, v2.h[2] \n" + "smlal2 v23.4s, v6.8h, v2.h[3] \n" + "smlal v24.4s, v6.4h, v2.h[4] \n" + "smlal v26.4s, v6.4h, v2.h[5] \n" + "smlal2 v25.4s, v6.8h, v2.h[4] \n" + "smlal2 v27.4s, v6.8h, v2.h[5] \n" + "smlal v28.4s, v6.4h, v2.h[6] \n" + "smlal v30.4s, v6.4h, v2.h[7] \n" + "smlal2 v29.4s, v6.8h, v2.h[6] \n" + "smlal2 v31.4s, v6.8h, v2.h[7] \n" + "smlal v16.4s, v7.4h, v3.h[0] \n" + "smlal v18.4s, v7.4h, v3.h[1] \n" + "smlal2 v17.4s, v7.8h, v3.h[0] \n" + "smlal2 v19.4s, v7.8h, v3.h[1] \n" + "smlal v20.4s, v7.4h, v3.h[2] \n" + "smlal v22.4s, v7.4h, v3.h[3] \n" + "smlal2 v21.4s, v7.8h, v3.h[2] \n" + "smlal2 v23.4s, v7.8h, v3.h[3] \n" + "subs w4, w4, #1 \n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #3 \n" // w4 = remain = max_kk & 3 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.8h}, [%2], #16 \n" + "smlal v16.4s, v4.4h, v0.h[0] \n" + "smlal v18.4s, v4.4h, v0.h[1] \n" + "smlal2 v17.4s, v4.8h, v0.h[0] \n" + "smlal2 v19.4s, v4.8h, v0.h[1] \n" + "smlal v20.4s, v4.4h, v0.h[2] \n" + "smlal v22.4s, v4.4h, v0.h[3] \n" + "smlal2 v21.4s, v4.8h, v0.h[2] \n" + "smlal2 v23.4s, v4.8h, v0.h[3] \n" + "subs w4, w4, #1 \n" + "smlal v24.4s, v4.4h, v0.h[4] \n" + "smlal v26.4s, v4.4h, v0.h[5] \n" + "smlal2 v25.4s, v4.8h, v0.h[4] \n" + "smlal2 v27.4s, v4.8h, v0.h[5] \n" + "smlal v28.4s, v4.4h, v0.h[6] \n" + "smlal v30.4s, v4.4h, v0.h[7] \n" + "smlal2 v29.4s, v4.8h, v0.h[6] \n" + 
"smlal2 v31.4s, v4.8h, v0.h[7] \n" + "bne 4b \n" + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_laneq_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_laneq_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_laneq_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_laneq_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_laneq_s16(_sum7, vget_high_s16(_pA), _pB, 3); + _sum8 = vmlal_laneq_s16(_sum8, vget_low_s16(_pA), _pB, 4); + _sum9 = vmlal_laneq_s16(_sum9, vget_high_s16(_pA), _pB, 4); + _suma = vmlal_laneq_s16(_suma, vget_low_s16(_pA), _pB, 5); + _sumb = vmlal_laneq_s16(_sumb, vget_high_s16(_pA), _pB, 5); + _sumc = vmlal_laneq_s16(_sumc, vget_low_s16(_pA), _pB, 6); + _sumd = vmlal_laneq_s16(_sumd, vget_high_s16(_pA), _pB, 6); + _sume = vmlal_laneq_s16(_sume, vget_low_s16(_pA), _pB, 7); + _sumf = vmlal_laneq_s16(_sumf, vget_high_s16(_pA), _pB, 7); + pA += 8; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 
52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + outptr += 64; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #384] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #128 \n" + "b 1f \n" + + "0: \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v20.4s, v6.4h, v0.h[0] \n" + "smlal v22.4s, v6.4h, v0.h[1] \n" + "ld1 {v8.8h, v9.8h}, [%1], #32 \n" + "smlal2 v21.4s, v6.8h, v0.h[0] \n" + "smlal2 v23.4s, v6.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v24.4s, v6.4h, v0.h[2] \n" + "smlal v26.4s, v6.4h, v0.h[3] \n" + "smlal2 v25.4s, v6.8h, v0.h[2] \n" + "smlal2 v27.4s, v6.8h, v0.h[3] \n" + "smlal v28.4s, v6.4h, v0.h[4] \n" + "smlal v30.4s, v6.4h, v0.h[5] \n" + "smlal2 v29.4s, v6.8h, v0.h[4] \n" + "smlal2 v31.4s, v6.8h, v0.h[5] \n" + "smlal v20.4s, v7.4h, v0.h[6] \n" + "smlal v22.4s, v7.4h, v0.h[7] \n" + "smlal2 v21.4s, v7.8h, v0.h[6] \n" + "smlal2 v23.4s, v7.8h, v0.h[7] \n" + "smlal v24.4s, v7.4h, v1.h[0] \n" + "smlal v26.4s, v7.4h, v1.h[1] \n" + "smlal2 v25.4s, v7.8h, v1.h[0] \n" + "smlal2 v27.4s, v7.8h, v1.h[1] \n" + "smlal v28.4s, v7.4h, v1.h[2] \n" + "smlal v30.4s, v7.4h, v1.h[3] \n" + "smlal2 v29.4s, v7.8h, v1.h[2] \n" + "smlal2 v31.4s, v7.8h, v1.h[3] \n" + "smlal v20.4s, v8.4h, v1.h[4] \n" + "smlal v22.4s, v8.4h, v1.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v21.4s, v8.8h, v1.h[4] \n" + "smlal2 v23.4s, v8.8h, v1.h[5] \n" + "smlal v24.4s, v8.4h, v1.h[6] \n" + "smlal v26.4s, v8.4h, v1.h[7] \n" + "smlal2 v25.4s, v8.8h, v1.h[6] \n" + "smlal2 v27.4s, v8.8h, v1.h[7] \n" + "smlal v28.4s, v8.4h, v2.h[0] \n" + "smlal v30.4s, v8.4h, v2.h[1] \n" + "ld1 {v4.8h, v5.8h}, [%2], #32 \n" + "smlal2 v29.4s, v8.8h, v2.h[0] \n" + "smlal2 v31.4s, v8.8h, v2.h[1] \n" + "smlal v20.4s, v9.4h, v2.h[2] \n" + "smlal v22.4s, v9.4h, v2.h[3] \n" + "smlal2 v21.4s, v9.8h, v2.h[2] \n" + "smlal2 v23.4s, v9.8h, v2.h[3] \n" + "smlal v24.4s, v9.4h, v2.h[4] \n" + "smlal v26.4s, v9.4h, v2.h[5] \n" + "smlal2 v25.4s, v9.8h, v2.h[4] \n" + "smlal2 v27.4s, v9.8h, v2.h[5] \n" + "smlal v28.4s, v9.4h, v2.h[6] \n" + "smlal v30.4s, v9.4h, v2.h[7] \n" + "smlal2 v29.4s, v9.8h, v2.h[6] \n" + "smlal2 v31.4s, v9.8h, v2.h[7] \n" + "smlal v20.4s, v6.4h, v3.h[0] \n" + "smlal v22.4s, v6.4h, v3.h[1] \n" + "ld1 {v8.8h, v9.8h}, [%1], #32 \n" + "smlal2 v21.4s, v6.8h, v3.h[0] \n" + "smlal2 v23.4s, v6.8h, v3.h[1] \n" + "smlal v24.4s, v6.4h, v3.h[2] \n" + "smlal v26.4s, v6.4h, v3.h[3] \n" + "smlal2 v25.4s, v6.8h, v3.h[2] \n" + "smlal2 v27.4s, v6.8h, v3.h[3] \n" + "smlal v28.4s, v6.4h, v3.h[4] \n" + "smlal v30.4s, v6.4h, v3.h[5] \n" 
+ "smlal2 v29.4s, v6.8h, v3.h[4] \n" + "smlal2 v31.4s, v6.8h, v3.h[5] \n" + "smlal v20.4s, v7.4h, v3.h[6] \n" + "smlal v22.4s, v7.4h, v3.h[7] \n" + "smlal2 v21.4s, v7.8h, v3.h[6] \n" + "smlal2 v23.4s, v7.8h, v3.h[7] \n" + "smlal v24.4s, v7.4h, v4.h[0] \n" + "smlal v26.4s, v7.4h, v4.h[1] \n" + "prfm pldl1keep, [%2, #384] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v25.4s, v7.8h, v4.h[0] \n" + "smlal2 v27.4s, v7.8h, v4.h[1] \n" + "smlal v28.4s, v7.4h, v4.h[2] \n" + "smlal v30.4s, v7.4h, v4.h[3] \n" + "smlal2 v29.4s, v7.8h, v4.h[2] \n" + "smlal2 v31.4s, v7.8h, v4.h[3] \n" + "smlal v20.4s, v8.4h, v4.h[4] \n" + "smlal v22.4s, v8.4h, v4.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v21.4s, v8.8h, v4.h[4] \n" + "smlal2 v23.4s, v8.8h, v4.h[5] \n" + "smlal v24.4s, v8.4h, v4.h[6] \n" + "smlal v26.4s, v8.4h, v4.h[7] \n" + "smlal2 v25.4s, v8.8h, v4.h[6] \n" + "smlal2 v27.4s, v8.8h, v4.h[7] \n" + "smlal v28.4s, v8.4h, v5.h[0] \n" + "smlal v30.4s, v8.4h, v5.h[1] \n" + "smlal2 v29.4s, v8.8h, v5.h[0] \n" + "smlal2 v31.4s, v8.8h, v5.h[1] \n" + "smlal v20.4s, v9.4h, v5.h[2] \n" + "smlal v22.4s, v9.4h, v5.h[3] \n" + "smlal2 v21.4s, v9.8h, v5.h[2] \n" + "smlal2 v23.4s, v9.8h, v5.h[3] \n" + "smlal v24.4s, v9.4h, v5.h[4] \n" + "smlal v26.4s, v9.4h, v5.h[5] \n" + "smlal2 v25.4s, v9.8h, v5.h[4] \n" + "smlal2 v27.4s, v9.8h, v5.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v9.4h, v5.h[6] \n" + "smlal v30.4s, v9.4h, v5.h[7] \n" + "smlal2 v29.4s, v9.8h, v5.h[6] \n" + "smlal2 v31.4s, v9.8h, v5.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.8h}, [%2] \n" + "add %2, %2, #12 \n" + "smlal v20.4s, v4.4h, v0.h[0] \n" + "smlal v22.4s, v4.4h, v0.h[1] \n" + "smlal2 v21.4s, v4.8h, v0.h[0] \n" + "smlal2 v23.4s, v4.8h, v0.h[1] \n" + "smlal v24.4s, v4.4h, v0.h[2] \n" + "smlal v26.4s, v4.4h, v0.h[3] \n" + "smlal2 v25.4s, v4.8h, v0.h[2] \n" + "smlal2 v27.4s, v4.8h, v0.h[3] \n" + "smlal v28.4s, v4.4h, v0.h[4] \n" + "smlal v30.4s, v4.4h, v0.h[5] \n" + "smlal2 v29.4s, v4.8h, v0.h[4] \n" + "smlal2 v31.4s, v4.8h, v0.h[5] \n" + "subs w4, w4, #1 \n" + "bne 4b \n" + + "5: \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #384] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0!, {d8-d15} \n" + "vldm %0, {d16-d31} \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "veor q4, q4 \n" + "veor q5, q5 \n" + "veor q6, q6 \n" + "veor q7, q7 \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #3 \n" // r4 = max_kk >> 3 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d4-d5}, [%1]! \n" + "vld1.s16 {d0-d1}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q4, d4, d0[0] \n" + "vld1.s16 {d6-d7}, [%1]! 
\n" + "vmlal.s16 q6, d4, d0[1] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q8, d4, d0[2] \n" + "vmlal.s16 q10, d4, d0[3] \n" + "vmlal.s16 q5, d5, d0[0] \n" + "vmlal.s16 q7, d5, d0[1] \n" + "vmlal.s16 q9, d5, d0[2] \n" + "vmlal.s16 q11, d5, d0[3] \n" + "vmlal.s16 q12, d4, d1[0] \n" + "vmlal.s16 q14, d4, d1[1] \n" + "vmlal.s16 q13, d5, d1[0] \n" + "vmlal.s16 q15, d5, d1[1] \n" + "vmlal.s16 q4, d6, d1[2] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d1[3] \n" + "vmlal.s16 q5, d7, d1[2] \n" + "vmlal.s16 q7, d7, d1[3] \n" + "vmlal.s16 q8, d6, d2[0] \n" + "pld [%2, #384] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q10, d6, d2[1] \n" + "vmlal.s16 q12, d6, d2[2] \n" + "vmlal.s16 q14, d6, d2[3] \n" + "vmlal.s16 q9, d7, d2[0] \n" + "vmlal.s16 q11, d7, d2[1] \n" + "vmlal.s16 q13, d7, d2[2] \n" + "vmlal.s16 q15, d7, d2[3] \n" + "vmlal.s16 q4, d4, d3[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d3[1] \n" + "vmlal.s16 q8, d4, d3[2] \n" + "vmlal.s16 q10, d4, d3[3] \n" + "vmlal.s16 q5, d5, d3[0] \n" + "vmlal.s16 q7, d5, d3[1] \n" + "vmlal.s16 q9, d5, d3[2] \n" + "vmlal.s16 q11, d5, d3[3] \n" + "vmlal.s16 q12, d4, d0[0] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q14, d4, d0[1] \n" + "vmlal.s16 q13, d5, d0[0] \n" + "vmlal.s16 q15, d5, d0[1] \n" + "vmlal.s16 q4, d6, d0[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d0[3] \n" + "vmlal.s16 q5, d7, d0[2] \n" + "vmlal.s16 q7, d7, d0[3] \n" + "vmlal.s16 q8, d6, d1[0] \n" + "vmlal.s16 q10, d6, d1[1] \n" + "vmlal.s16 q12, d6, d1[2] \n" + "vmlal.s16 q14, d6, d1[3] \n" + "vmlal.s16 q9, d7, d1[0] \n" + "vmlal.s16 q11, d7, d1[1] \n" + "vmlal.s16 q13, d7, d1[2] \n" + "vmlal.s16 q15, d7, d1[3] \n" + "vmlal.s16 q4, d4, d2[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d2[1] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q8, d4, d2[2] \n" + "vmlal.s16 q10, d4, d2[3] \n" + "vmlal.s16 q5, d5, d2[0] \n" + "vmlal.s16 q7, d5, d2[1] \n" + "vmlal.s16 q9, d5, d2[2] \n" + "vmlal.s16 q11, d5, d2[3] \n" + "vmlal.s16 q12, d4, d3[0] \n" + "vmlal.s16 q14, d4, d3[1] \n" + "vmlal.s16 q13, d5, d3[0] \n" + "vmlal.s16 q15, d5, d3[1] \n" + "vmlal.s16 q4, d6, d3[2] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d3[3] \n" + "vmlal.s16 q5, d7, d3[2] \n" + "vmlal.s16 q7, d7, d3[3] \n" + "vmlal.s16 q8, d6, d0[0] \n" + "pld [%2, #384] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q10, d6, d0[1] \n" + "vmlal.s16 q12, d6, d0[2] \n" + "vmlal.s16 q14, d6, d0[3] \n" + "vmlal.s16 q9, d7, d0[0] \n" + "vmlal.s16 q11, d7, d0[1] \n" + "vmlal.s16 q13, d7, d0[2] \n" + "vmlal.s16 q15, d7, d0[3] \n" + "vmlal.s16 q4, d4, d1[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d1[1] \n" + "vmlal.s16 q8, d4, d1[2] \n" + "vmlal.s16 q10, d4, d1[3] \n" + "vmlal.s16 q5, d5, d1[0] \n" + "vmlal.s16 q7, d5, d1[1] \n" + "vmlal.s16 q9, d5, d1[2] \n" + "vmlal.s16 q11, d5, d1[3] \n" + "vmlal.s16 q12, d4, d2[0] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q14, d4, d2[1] \n" + "vmlal.s16 q13, d5, d2[0] \n" + "vmlal.s16 q15, d5, d2[1] \n" + "vmlal.s16 q4, d6, d2[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! 
\n" + "vmlal.s16 q6, d6, d2[3] \n" + "vmlal.s16 q5, d7, d2[2] \n" + "vmlal.s16 q7, d7, d2[3] \n" + "vmlal.s16 q8, d6, d3[0] \n" + "vmlal.s16 q10, d6, d3[1] \n" + "vmlal.s16 q12, d6, d3[2] \n" + "vmlal.s16 q14, d6, d3[3] \n" + "vmlal.s16 q9, d7, d3[0] \n" + "vmlal.s16 q11, d7, d3[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d7, d3[2] \n" + "vmlal.s16 q15, d7, d3[3] \n" + "bne 2b \n" + "sub %1, %1, #16 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and r4, %6, #7 \n" // w4 = remain = max_kk & 7 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2-d3}, [%2] \n" + "add %2, %2, #12 \n" + "vmlal.s16 q4, d0, d2[0] \n" + "vmlal.s16 q6, d0, d2[1] \n" + "vmlal.s16 q8, d0, d2[2] \n" + "vmlal.s16 q10, d0, d2[3] \n" + "vmlal.s16 q5, d1, d2[0] \n" + "vmlal.s16 q7, d1, d2[1] \n" + "vmlal.s16 q9, d1, d2[2] \n" + "vmlal.s16 q11, d1, d2[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q12, d0, d3[0] \n" + "vmlal.s16 q14, d0, d3[1] \n" + "vmlal.s16 q13, d1, d3[0] \n" + "vmlal.s16 q15, d1, d3[1] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d8-d15} \n" + "vstm %0!, {d16-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pA), vget_low_s16(_pB), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_pA), vget_low_s16(_pB), 0); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_pA), vget_low_s16(_pB), 1); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_pA), vget_low_s16(_pB), 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_pA), vget_low_s16(_pB), 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_pA), vget_low_s16(_pB), 2); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_pA), vget_low_s16(_pB), 3); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_pA), vget_low_s16(_pB), 3); + _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_pA), vget_high_s16(_pB), 0); + _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_pA), vget_high_s16(_pB), 0); + _suma = vmlal_lane_s16(_suma, vget_low_s16(_pA), vget_high_s16(_pB), 1); + _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_pA), vget_high_s16(_pB), 1); + pA += 8; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 
12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + outptr += 48; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v24.4s, v4.4h, v0.h[0] \n" + "smlal v26.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v25.4s, v4.8h, v0.h[0] \n" + "smlal2 v27.4s, v4.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v28.4s, v4.4h, v0.h[2] \n" + "smlal v30.4s, v4.4h, v0.h[3] \n" + "smlal2 v29.4s, v4.8h, v0.h[2] \n" + "smlal2 v31.4s, v4.8h, v0.h[3] \n" + "smlal v24.4s, v5.4h, v0.h[4] \n" + "smlal v26.4s, v5.4h, v0.h[5] \n" + "smlal2 v25.4s, v5.8h, v0.h[4] \n" + "smlal2 v27.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v5.4h, v0.h[6] \n" + "smlal v30.4s, v5.4h, v0.h[7] \n" + "smlal2 v29.4s, v5.8h, v0.h[6] \n" + "smlal2 v31.4s, v5.8h, v0.h[7] \n" + "smlal v24.4s, v6.4h, v1.h[0] \n" + "smlal v26.4s, v6.4h, v1.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v25.4s, v6.8h, v1.h[0] \n" + "smlal2 v27.4s, v6.8h, v1.h[1] \n" + "smlal v28.4s, v6.4h, v1.h[2] \n" + "smlal v30.4s, v6.4h, v1.h[3] \n" + "smlal2 v29.4s, v6.8h, v1.h[2] \n" + "smlal2 v31.4s, v6.8h, v1.h[3] \n" + "smlal v24.4s, v7.4h, v1.h[4] \n" + "smlal v26.4s, v7.4h, v1.h[5] \n" + "smlal2 v25.4s, v7.8h, v1.h[4] \n" + "smlal2 v27.4s, v7.8h, v1.h[5] \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "smlal v24.4s, v4.4h, v2.h[0] \n" + "smlal v26.4s, v4.4h, v2.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v25.4s, v4.8h, v2.h[0] \n" + "smlal2 v27.4s, v4.8h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal v28.4s, v4.4h, v2.h[2] \n" + "smlal v30.4s, v4.4h, v2.h[3] \n" + "smlal2 v29.4s, v4.8h, v2.h[2] \n" + "smlal2 v31.4s, v4.8h, v2.h[3] \n" + "smlal v24.4s, v5.4h, v2.h[4] \n" + "smlal v26.4s, v5.4h, v2.h[5] \n" + "smlal2 v25.4s, v5.8h, v2.h[4] \n" + "smlal2 v27.4s, v5.8h, v2.h[5] \n" + "smlal v28.4s, v5.4h, v2.h[6] \n" + "smlal v30.4s, v5.4h, v2.h[7] \n" + "smlal2 v29.4s, v5.8h, v2.h[6] \n" + "smlal2 v31.4s, v5.8h, v2.h[7] \n" + "smlal v24.4s, v6.4h, v3.h[0] \n" + "smlal v26.4s, v6.4h, v3.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v25.4s, v6.8h, v3.h[0] \n" + "smlal2 v27.4s, v6.8h, v3.h[1] \n" + "smlal v28.4s, v6.4h, v3.h[2] \n" + "smlal v30.4s, v6.4h, v3.h[3] \n" + "smlal2 v29.4s, v6.8h, v3.h[2] \n" + "smlal2 v31.4s, v6.8h, v3.h[3] 
\n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h}, [%2], #8 \n" + "smlal v24.4s, v4.4h, v0.h[0] \n" + "smlal v26.4s, v4.4h, v0.h[1] \n" + "smlal2 v25.4s, v4.8h, v0.h[0] \n" + "smlal2 v27.4s, v4.8h, v0.h[1] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v4.4h, v0.h[2] \n" + "smlal v30.4s, v4.4h, v0.h[3] \n" + "smlal2 v29.4s, v4.8h, v0.h[2] \n" + "smlal2 v31.4s, v4.8h, v0.h[3] \n" + "bne 4b \n" + + "5: \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #256] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d16-d31} \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d4-d5}, [%1]! \n" + "vld1.s16 {d0-d1}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q8, d4, d0[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q10, d4, d0[1] \n" + "vmlal.s16 q12, d4, d0[2] \n" + "vmlal.s16 q14, d4, d0[3] \n" + "vmlal.s16 q9, d5, d0[0] \n" + "vld1.s16 {d8-d9}, [%1]! \n" + "vmlal.s16 q11, d5, d0[1] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q13, d5, d0[2] \n" + "vmlal.s16 q15, d5, d0[3] \n" + "vmlal.s16 q8, d6, d1[0] \n" + "vmlal.s16 q10, d6, d1[1] \n" + "vmlal.s16 q12, d6, d1[2] \n" + "vmlal.s16 q14, d6, d1[3] \n" + "vmlal.s16 q9, d7, d1[0] \n" + "vld1.s16 {d10-d11}, [%1]! \n" + "vmlal.s16 q11, d7, d1[1] \n" + "vmlal.s16 q13, d7, d1[2] \n" + "vmlal.s16 q15, d7, d1[3] \n" + "vmlal.s16 q8, d8, d2[0] \n" + "vmlal.s16 q10, d8, d2[1] \n" + "vmlal.s16 q12, d8, d2[2] \n" + "vmlal.s16 q14, d8, d2[3] \n" + "vmlal.s16 q9, d9, d2[0] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q11, d9, d2[1] \n" + "pld [%2, #256] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q13, d9, d2[2] \n" + "vmlal.s16 q15, d9, d2[3] \n" + "vmlal.s16 q8, d10, d3[0] \n" + "vmlal.s16 q10, d10, d3[1] \n" + "vmlal.s16 q12, d10, d3[2] \n" + "vmlal.s16 q14, d10, d3[3] \n" + "vmlal.s16 q9, d11, d3[0] \n" + "vmlal.s16 q11, d11, d3[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d11, d3[2] \n" + "vmlal.s16 q15, d11, d3[3] \n" + "bne 2b \n" + "sub %1, %1, #16 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2}, [%2]! 
\n" + "vmlal.s16 q8, d0, d2[0] \n" + "vmlal.s16 q10, d0, d2[1] \n" + "vmlal.s16 q12, d0, d2[2] \n" + "vmlal.s16 q14, d0, d2[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q9, d1, d2[0] \n" + "vmlal.s16 q11, d1, d2[1] \n" + "vmlal.s16 q13, d1, d2[2] \n" + "vmlal.s16 q15, d1, d2[3] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d16-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_pA), _pB, 3); + pA += 8; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #256] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h}, [%2], #16 \n" + ".align 4 \n" + "2: \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "smlal v30.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[1] \n" + "ld1 {v1.8h}, [%2], #16 \n" + "smlal v28.4s, v5.4h, v0.h[2] \n" + "smlal v30.4s, v5.4h, v0.h[3] \n" + "smlal2 v29.4s, v5.8h, v0.h[2] \n" + "smlal2 v31.4s, v5.8h, v0.h[3] \n" + "smlal v28.4s, v6.4h, v0.h[4] \n" + "smlal v30.4s, v6.4h, v0.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[4] \n" + "smlal2 v31.4s, v6.8h, v0.h[5] \n" + "smlal v28.4s, v7.4h, v0.h[6] \n" + "smlal v30.4s, v7.4h, v0.h[7] \n" + "smlal2 v29.4s, v7.8h, 
v0.h[6] \n" + "smlal2 v31.4s, v7.8h, v0.h[7] \n" + "smlal v28.4s, v4.4h, v1.h[0] \n" + "smlal v30.4s, v4.4h, v1.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v1.h[0] \n" + "smlal2 v31.4s, v4.8h, v1.h[1] \n" + "prfm pldl1keep, [%2, #256] \n" + "ld1 {v0.8h}, [%2], #16 \n" + "smlal v28.4s, v5.4h, v1.h[2] \n" + "smlal v30.4s, v5.4h, v1.h[3] \n" + "smlal2 v29.4s, v5.8h, v1.h[2] \n" + "smlal2 v31.4s, v5.8h, v1.h[3] \n" + "smlal v28.4s, v6.4h, v1.h[4] \n" + "smlal v30.4s, v6.4h, v1.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v1.h[4] \n" + "smlal2 v31.4s, v6.8h, v1.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h}, [%2] \n" + "add %2, %2, #4 \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "smlal v30.4s, v4.4h, v0.h[1] \n" + "subs w4, w4, #1 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[1] \n" + "bne 4b \n" + + "5: \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #128] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d24-d31} \n" + "b 1f \n" + + "0: \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d2-d5}, [%1]! \n" + "vld1.s16 {d0}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q12, d2, d0[0] \n" + "vld1.s16 {d6-d9}, [%1]! \n" + "vmlal.s16 q14, d2, d0[1] \n" + "vld1.s16 {d1}, [%2]! \n" + "vmlal.s16 q13, d3, d0[0] \n" + "vmlal.s16 q15, d3, d0[1] \n" + "vmlal.s16 q12, d4, d0[2] \n" + "vmlal.s16 q14, d4, d0[3] \n" + "vmlal.s16 q13, d5, d0[2] \n" + "vmlal.s16 q15, d5, d0[3] \n" + "vmlal.s16 q12, d6, d1[0] \n" + "pld [%1, #512] \n" + "vld1.s16 {d2-d5}, [%1]! \n" + "vmlal.s16 q14, d6, d1[1] \n" + "pld [%2, #128] \n" + "vld1.s16 {d0}, [%2]! \n" + "vmlal.s16 q13, d7, d1[0] \n" + "vmlal.s16 q15, d7, d1[1] \n" + "vmlal.s16 q12, d8, d1[2] \n" + "vmlal.s16 q14, d8, d1[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d9, d1[2] \n" + "vmlal.s16 q15, d9, d1[3] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #8 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! 
\n" + "vld1.s16 {d2}, [%2] \n" + "add %2, %2, #4 \n" + "vmlal.s16 q12, d0, d2[0] \n" + "vmlal.s16 q14, d0, d2[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d1, d2[0] \n" + "vmlal.s16 q15, d1, d2[1] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d24-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB0 = vdup_n_s16(pB[0]); + int16x4_t _pB1 = vdup_n_s16(pB[1]); + _sum0 = vmlal_s16(_sum0, vget_low_s16(_pA), _pB0); + _sum1 = vmlal_s16(_sum1, vget_high_s16(_pA), _pB0); + _sum2 = vmlal_s16(_sum2, vget_low_s16(_pA), _pB1); + _sum3 = vmlal_s16(_sum3, vget_high_s16(_pA), _pB1); + pA += 8; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + outptr += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #128] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v30.4s, v31.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v1.8h}, [%2], #16 \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + ".align 4 \n" + "2: \n" + "mov v0.16b, v1.16b \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v1.8h}, [%2], #16 \n" + "smlal v30.4s, v5.4h, v0.h[1] \n" + "smlal2 v31.4s, v5.8h, v0.h[1] \n" + "smlal v28.4s, v6.4h, v0.h[2] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[2] \n" + "smlal v30.4s, v7.4h, v0.h[3] \n" + "smlal2 v31.4s, v7.8h, v0.h[3] \n" + "smlal v28.4s, v4.4h, v0.h[4] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[4] \n" + "smlal v30.4s, v5.4h, v0.h[5] \n" + "smlal2 v31.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v6.4h, v0.h[6] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[6] \n" + "subs w4, w4, #1 \n" + "smlal v30.4s, v7.4h, v0.h[7] \n" + "smlal2 v31.4s, v7.8h, v0.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #16 \n" + "add v30.4s, v30.4s, v28.4s \n" + "add v31.4s, v31.4s, v29.4s \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1r {v0.4h}, [%2], #2 \n" + "subs w4, w4, #1 \n" + "smlal v30.4s, v4.4h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[0] \n" + "bne 4b \n" + + "5: \n" + "st1 {v30.4s, v31.4s}, [%0], #32 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + 
"1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #64] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vld1.s32 {d28-d31}, [%0] \n" + "b 1f \n" + + "0: \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d2-d5}, [%1]! \n" + ".align 4 \n" + "2: \n" + "pld [%2, #64] \n" + "vld1.s16 {d0}, [%2]! \n" + "vmlal.s16 q14, d2, d0[0] \n" + "vld1.s16 {d6-d9}, [%1]! \n" + "vmlal.s16 q15, d3, d0[0] \n" + "vmlal.s16 q14, d4, d0[1] \n" + "vmlal.s16 q15, d5, d0[1] \n" + "vmlal.s16 q14, d6, d0[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d2-d5}, [%1]! \n" + "vmlal.s16 q15, d7, d0[2] \n" + "vmlal.s16 q14, d8, d0[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q15, d9, d0[3] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2[]}, [%2]! \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q14, d0, d2[0] \n" + "vmlal.s16 q15, d1, d2[0] \n" + "bne 4b \n" + + "5: \n" + "vst1.s32 {d28-d31}, [%0]! \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB = vld1_dup_s16(pB); + _sum0 = vmlal_s16(_sum0, vget_low_s16(_pA), _pB); + _sum1 = vmlal_s16(_sum1, vget_high_s16(_pA), _pB); + pA += 8; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; +#endif // NCNN_GNU_INLINE_ASM + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + } + + int kk = 0; 
+ for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_laneq_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_laneq_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_laneq_s16(_sum3, _pA, _pB, 3); + _sum4 = vmlal_laneq_s16(_sum4, _pA, _pB, 4); + _sum5 = vmlal_laneq_s16(_sum5, _pA, _pB, 5); + _sum6 = vmlal_laneq_s16(_sum6, _pA, _pB, 6); + _sum7 = vmlal_laneq_s16(_sum7, _pA, _pB, 7); + _sum8 = vmlal_lane_s16(_sum8, _pA, _pB2, 0); + _sum9 = vmlal_lane_s16(_sum9, _pA, _pB2, 1); + _suma = vmlal_lane_s16(_suma, _pA, _pB2, 2); + _sumb = vmlal_lane_s16(_sumb, _pA, _pB2, 3); + pA += 4; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_laneq_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_laneq_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_laneq_s16(_sum3, _pA, _pB, 3); + _sum4 = vmlal_laneq_s16(_sum4, _pA, _pB, 4); + _sum5 = vmlal_laneq_s16(_sum5, _pA, _pB, 5); + _sum6 = vmlal_laneq_s16(_sum6, _pA, _pB, 6); + _sum7 = vmlal_laneq_s16(_sum7, _pA, _pB, 7); + pA += 4; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + outptr += 32; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, _pA, vget_low_s16(_pB), 0); + _sum1 = vmlal_lane_s16(_sum1, _pA, vget_low_s16(_pB), 1); + _sum2 = vmlal_lane_s16(_sum2, _pA, vget_low_s16(_pB), 2); + 
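+                    // columns 3-5 of the 4x6 tile: lane 3 of the low half plus lanes 0-1
+                    // of the high half; only 6 of the 8 loaded B values are used (pB += 6).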
_sum3 = vmlal_lane_s16(_sum3, _pA, vget_low_s16(_pB), 3); + _sum4 = vmlal_lane_s16(_sum4, _pA, vget_high_s16(_pB), 0); + _sum5 = vmlal_lane_s16(_sum5, _pA, vget_high_s16(_pB), 1); + pA += 4; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + outptr += 24; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_lane_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_lane_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_lane_s16(_sum3, _pA, _pB, 3); + pA += 4; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB0 = vdup_n_s16(pB[0]); + int16x4_t _pB1 = vdup_n_s16(pB[1]); + _sum0 = vmlal_s16(_sum0, _pA, _pB0); + _sum1 = vmlal_s16(_sum1, _pA, _pB1); + pA += 4; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB = vld1_dup_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, _pB); + pA += 4; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + outptr += 4; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + int32x4x2_t _s23 = vld2q_s32(outptr + 8); + int32x4x2_t _s45 = vld2q_s32(outptr + 16); + _sum0 = _s01.val[0]; + _sum3 = _s01.val[1]; + _sum1 = _s23.val[0]; + _sum4 = _s23.val[1]; + _sum2 = _s45.val[0]; + _sum5 = _s45.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_s16(_sum0, _pA0, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA0, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA0, _pB2); + _sum3 = 
vmlal_s16(_sum3, _pA1, vget_low_s16(_pB)); + _sum4 = vmlal_s16(_sum4, _pA1, vget_high_s16(_pB)); + _sum5 = vmlal_s16(_sum5, _pA1, _pB2); + pA += 2; + pB += 12; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum3; + int32x4x2_t _s23; + _s23.val[0] = _sum1; + _s23.val[1] = _sum4; + int32x4x2_t _s45; + _s45.val[0] = _sum2; + _s45.val[1] = _sum5; + vst2q_s32(outptr, _s01); + vst2q_s32(outptr + 8, _s23); + vst2q_s32(outptr + 16, _s45); + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + int32x4x2_t _s23 = vld2q_s32(outptr + 8); + _sum0 = _s01.val[0]; + _sum2 = _s01.val[1]; + _sum1 = _s23.val[0]; + _sum3 = _s23.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA0, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA0, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA1, vget_low_s16(_pB)); + _sum3 = vmlal_s16(_sum3, _pA1, vget_high_s16(_pB)); + pA += 2; + pB += 8; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum2; + int32x4x2_t _s23; + _s23.val[0] = _sum1; + _s23.val[1] = _sum3; + vst2q_s32(outptr, _s01); + vst2q_s32(outptr + 8, _s23); + outptr += 16; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + _sum0 = _s01.val[0]; + _sum1 = _s01.val[1]; + _sum2 = vld1q_s32(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vreinterpret_s16_s32(vld1_dup_s32((const int*)pA)); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vzip_s16(vget_high_s16(_pB), vget_high_s16(_pB)).val[0]; + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pB), _pA, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_pB), _pA, 1); + _sum2 = vmlal_s16(_sum2, _pA, _pB2); + pA += 2; + pB += 6; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum1; + vst2q_s32(outptr, _s01); + vst1q_s32(outptr + 8, _sum2); + outptr += 12; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + _sum0 = _s01.val[0]; + _sum1 = _s01.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA0, _pB); + _sum1 = vmlal_s16(_sum1, _pA1, _pB); + pA += 2; + pB += 4; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum1; + vst2q_s32(outptr, _s01); + outptr += 8; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum00 = 0; + int sum01 = 0; + int sum10 = 0; + int sum11 = 0; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + 
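+                // SMLAD path for non-NEON cores with the ARM DSP (SIMD32) extension:
+                // smlad performs two signed 16x16 multiplies and adds both products to a
+                // 32-bit accumulator, so each of the four sums takes one instruction per two k steps.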
for (; kk + 1 < max_kk; kk += 2) + { + // fomit-frame-pointer implied in optimized flag spare one register + // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui +#if __OPTIMIZE__ + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA0 = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%0], #4 \n" // int16x2_t _pA1 = *((int16x2_t*)pA); pA += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB0 = *((int16x2_t*)pB); pB += 2; + "ldr r5, [%1], #4 \n" // int16x2_t _pB1 = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_pA0, _pB0, sum00); + "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_pA1, _pB0, sum01); + "smlad %4, r2, r5, %4 \n" // sum10 = __smlad(_pA0, _pB1, sum10); + "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_pA1, _pB1, sum11); + : "=r"(pA), + "=r"(pB), + "=r"(sum00), + "=r"(sum01), + "=r"(sum10), + "=r"(sum11) + : "0"(pA), + "1"(pB), + "2"(sum00), + "3"(sum01), + "4"(sum10), + "5"(sum11) + : "memory", "r2", "r3", "r4", "r5"); +#else + int _pA0 = *((int*)pA); + int _pA1 = *((int*)(pA + 2)); + int _pB0 = *((int*)pB); + int _pB1 = *((int*)(pB + 2)); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum00) + : "0"(sum00), "r"(_pA0), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum01) + : "0"(sum01), "r"(_pA1), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum10) + : "0"(sum10), "r"(_pA0), "r"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum11) + : "0"(sum11), "r"(_pA1), "r"(_pB1) + :); + pA += 4; + pB += 4; +#endif + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA0 = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%0], #4 \n" // int16x2_t _pA1 = *((int16x2_t*)pA); pA += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r4, %2 \n" // sum0 = __smlad(_pA0, _pB, sum0); + "smlad %3, r3, r4, %3 \n" // sum1 = __smlad(_pA1, _pB, sum1); + : "=r"(pA), + "=r"(pB), + "=r"(sum0), + "=r"(sum1) + : "0"(pA), + "1"(pB), + "2"(sum0), + "3"(sum1) + : "memory", "r2", "r3", "r4"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + } + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = 
vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA, _pB2); + pA += 1; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + pA += 1; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + pA += 1; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1_s32(outptr + 4, vget_low_s32(_sum1)); + outptr += 6; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, _pB); + pA += 1; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + outptr += 4; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%1], #4 \n" // int16x2_t _pB0 = *((int16x2_t*)pB); pB += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB1 = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r3, %2 \n" // sum0 = __smlad(_pA, _pB0, sum0); + "smlad %3, r2, r4, %3 \n" // sum1 = __smlad(_pA, _pB1, sum1); + : "=r"(pA), + "=r"(pB), + "=r"(sum0), + "=r"(sum1) + : "0"(pA), + "1"(pB), + "2"(sum0), + "3"(sum1) + : "memory", "r2", "r3", "r4"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum = 0; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%1], #4 \n" // int16x2_t _pB = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r3, %2 \n" // sum = __smlad(_pA, _pB, sum); + : 
"=r"(pA), + "=r"(pB), + "=r"(sum) + : "0"(pA), + "1"(pB), + "2"(sum) + : "memory", "r2", "r3"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + outptr += 1; + } + } + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const int l2_cache_size_int8 = (int)(get_cpu_level2_cache_size() / sizeof(short)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // we shall take B into account for batched gemm, but that will be slower on arm in practice, why ? + // (void)B; + + // solve K + { + // try not to split K +#if __aarch64__ + int tile_size = (l2_cache_size_int8 - 32) / 12; +#elif __ARM_NEON + int tile_size = (l2_cache_size_int8 - 32) / 6; +#else + int tile_size = (l2_cache_size_int8 - 2) / 3; +#endif + +#if __aarch64__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __ARM_NEON + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __aarch64__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __ARM_NEON + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + // solve M + { +#if __ARM_NEON + TILE_M = 8; +#else + TILE_M = 2; +#endif + } + + { + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __ARM_NEON + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __ARM_NEON + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + +#if __ARM_NEON + TILE_M = std::max(8, TILE_M); +#else + TILE_M = std::max(2, TILE_M); +#endif + } + + if (N > 0) + { + int tile_size; + if (TILE_K >= K) + { + tile_size = (l2_cache_size_int8 - TILE_M * TILE_K) / TILE_K; + } + else + { + tile_size = (l2_cache_size_int8 - TILE_M * TILE_K) / (TILE_M * 2 + TILE_K); + } + +#if __aarch64__ + TILE_N = std::max(4, tile_size / 4 * 4); +#elif __ARM_NEON + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; + +#if __aarch64__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __ARM_NEON + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + +#if __aarch64__ + TILE_N = std::max(4, TILE_N); +#elif __ARM_NEON + TILE_N = std::max(4, TILE_N); +#else + TILE_N = std::max(1, TILE_N); +#endif + } +} + +static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const signed char ktm[4][3] = { + // {2, 0, 0}, + // {1, 1, 1}, + // {1, -1, 1}, + // {0, 0, 2} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[4][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 2; + tmp[1][m] = r0 + r1 + r2; + 
tmp[2][m] = r0 - r1 + r2; + tmp[3][m] = r2 * 2; + + k0 += 3; + } + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 2; + short z1 = r0 + r1 + r2; + short z2 = r0 - r1 + r2; + short z3 = r2 * 2; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp += 4; + } + } + } +} + +static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ + const int M = outch; + const int K = inch; + const int B = 16; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd23_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd23_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const signed char itm[4][4] = { + // {1, 0, -1, 0}, + // {0, 1, 1, 0}, + // {0, -1, 1, 0}, + // {0, -1, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w - 1) / 2; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __ARM_NEON + nn_max_kk = max_kk / 8; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + + short tmp[4][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + int8x8_t _r0 = vdup_n_s8(0); + int8x8_t _r1 = vdup_n_s8(0); + int8x8_t _r2 = vdup_n_s8(0); + int8x8_t _r3 = vdup_n_s8(0); + + if (ti * 2 + m < h) + { + if (elempack == 8) + { + _r0 = vld1_s8(r0); + if (tj * 2 + 1 < w) _r1 = vld1_s8(r0 + 8); + if (tj * 2 + 2 < w) _r2 = vld1_s8(r0 + 16); + if (tj * 2 + 3 < w) _r3 = vld1_s8(r0 + 24); + } + if (elempack == 1) + { + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + int8x8_t _t0 = vld1_s8(r0); + int8x8_t _t1 = vld1_s8(r1); + int8x8_t _t2 = vld1_s8(r2); + int8x8_t _t3 = vld1_s8(r3); + int8x8_t _t4 = vld1_s8(r4); + int8x8_t _t5 = vld1_s8(r5); + int8x8_t _t6 = vld1_s8(r6); + int8x8_t _t7 = vld1_s8(r7); + + int8x8_t _t01 = vzip_s8(_t0, _t1).val[0]; + int8x8_t _t23 = vzip_s8(_t2, _t3).val[0]; + int8x8_t _t45 = vzip_s8(_t4, _t5).val[0]; + int8x8_t _t67 = vzip_s8(_t6, _t7).val[0]; + int16x4x2_t _t0123 = vzip_s16(vreinterpret_s16_s8(_t01), 
vreinterpret_s16_s8(_t23)); + int16x4x2_t _t4567 = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)); + int16x8_t _ta = vcombine_s16(_t0123.val[0], _t0123.val[1]); + int16x8_t _tb = vcombine_s16(_t4567.val[0], _t4567.val[1]); + int32x4x2_t _tab = vzipq_s32(vreinterpretq_s32_s16(_ta), vreinterpretq_s32_s16(_tb)); + + _r0 = vreinterpret_s8_s32(vget_low_s32(_tab.val[0])); + if (tj * 2 + 1 < w) _r1 = vreinterpret_s8_s32(vget_high_s32(_tab.val[0])); + if (tj * 2 + 2 < w) _r2 = vreinterpret_s8_s32(vget_low_s32(_tab.val[1])); + if (tj * 2 + 3 < w) _r3 = vreinterpret_s8_s32(vget_high_s32(_tab.val[1])); + } + } + + int16x8_t _tmp0 = vsubl_s8(_r0, _r2); + int16x8_t _tmp1 = vaddl_s8(_r1, _r2); + int16x8_t _tmp2 = vsubl_s8(_r2, _r1); + int16x8_t _tmp3 = vsubl_s8(_r3, _r1); + + vst1q_s16(tmp[0][m], _tmp0); + vst1q_s16(tmp[1][m], _tmp1); + vst1q_s16(tmp[2][m], _tmp2); + vst1q_s16(tmp[3][m], _tmp3); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + int16x8_t _r0 = vld1q_s16(tmp[m][0]); + int16x8_t _r1 = vld1q_s16(tmp[m][1]); + int16x8_t _r2 = vld1q_s16(tmp[m][2]); + int16x8_t _r3 = vld1q_s16(tmp[m][3]); + + int16x8_t _tmp0 = vsubq_s16(_r0, _r2); + int16x8_t _tmp1 = vaddq_s16(_r1, _r2); + int16x8_t _tmp2 = vsubq_s16(_r2, _r1); + int16x8_t _tmp3 = vsubq_s16(_r3, _r1); + + vst1q_s16(p0, _tmp0); + vst1q_s16(p1, _tmp1); + vst1q_s16(p2, _tmp2); + vst1q_s16(p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __ARM_NEON + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __ARM_NEON + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[4][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 2 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 2 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 2 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + } + } + + tmp[0][m][0] = r00 - r20; + tmp[0][m][1] = r01 - r21; + tmp[1][m][0] = r10 + r20; + tmp[1][m][1] = r11 + r21; + tmp[2][m][0] = r20 - r10; + tmp[2][m][1] = r21 - r11; + tmp[3][m][0] = r30 - r10; + tmp[3][m][1] = r31 - r11; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = 
r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 2 + 1 < w) r1 = r0123[1]; + if (tj * 2 + 2 < w) r2 = r0123[2]; + if (tj * 2 + 3 < w) r3 = r0123[3]; + } + } + + tmp[0][m] = r0 - r2; + tmp[1][m] = r1 + r2; + tmp[2][m] = r2 - r1; + tmp[3][m] = r3 - r1; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; + } + } + } +} + +static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[2][4] = { + // {1, 1, 1, 0}, + // {0, 1, -1, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 1) / 2; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + int tmp[2][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_r00, _r10), _r20); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_r01, _r11), _r21); + int32x4_t _tmp10 = vaddq_s32(vsubq_s32(_r10, _r20), _r30); + int32x4_t _tmp11 = vaddq_s32(vsubq_s32(_r11, _r21), _r31); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int32x4_t _r00 = vld1q_s32(tmp[m][0]); + int32x4_t _r01 = vld1q_s32(tmp[m][0] + 4); + int32x4_t _r10 = vld1q_s32(tmp[m][1]); + int32x4_t _r11 = vld1q_s32(tmp[m][1] + 4); + int32x4_t _r20 = vld1q_s32(tmp[m][2]); + int32x4_t _r21 = vld1q_s32(tmp[m][2] + 4); + int32x4_t _r30 = vld1q_s32(tmp[m][3]); + int32x4_t _r31 = vld1q_s32(tmp[m][3] + 4); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_r00, _r10), _r20); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_r01, 
_r11), _r21); + int32x4_t _tmp10 = vaddq_s32(vsubq_s32(_r10, _r20), _r30); + int32x4_t _tmp11 = vaddq_s32(vsubq_s32(_r11, _r21), _r31); + + _tmp00 = vshrq_n_s32(_tmp00, 2); + _tmp01 = vshrq_n_s32(_tmp01, 2); + _tmp10 = vshrq_n_s32(_tmp10, 2); + _tmp11 = vshrq_n_s32(_tmp11, 2); + + if (out_elempack == 8) + { + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr0 + 4, _tmp01); + if (tj * 2 + 1 < outw) + { + vst1q_s32(outptr0 + 8, _tmp10); + vst1q_s32(outptr0 + 12, _tmp11); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr1, _tmp01); + if (tj * 2 + 1 < outw) + { + vst1q_s32(outptr0 + 4, _tmp10); + vst1q_s32(outptr1 + 4, _tmp11); + } + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = vgetq_lane_s32(_tmp00, 0); + outptr1[0] = vgetq_lane_s32(_tmp00, 1); + outptr2[0] = vgetq_lane_s32(_tmp00, 2); + outptr3[0] = vgetq_lane_s32(_tmp00, 3); + outptr4[0] = vgetq_lane_s32(_tmp01, 0); + outptr5[0] = vgetq_lane_s32(_tmp01, 1); + outptr6[0] = vgetq_lane_s32(_tmp01, 2); + outptr7[0] = vgetq_lane_s32(_tmp01, 3); + + if (tj * 2 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp10, 0); + outptr1[1] = vgetq_lane_s32(_tmp10, 1); + outptr2[1] = vgetq_lane_s32(_tmp10, 2); + outptr3[1] = vgetq_lane_s32(_tmp10, 3); + outptr4[1] = vgetq_lane_s32(_tmp11, 0); + outptr5[1] = vgetq_lane_s32(_tmp11, 1); + outptr6[1] = vgetq_lane_s32(_tmp11, 2); + outptr7[1] = vgetq_lane_s32(_tmp11, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + int tmp[2][4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + + for (int m = 0; m < 4; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_r0, _r1), _r2); + int32x4_t _tmp1 = vaddq_s32(vsubq_s32(_r1, _r2), _r3); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int32x4_t _r0 = vld1q_s32(tmp[m][0]); + int32x4_t _r1 = vld1q_s32(tmp[m][1]); + int32x4_t _r2 = vld1q_s32(tmp[m][2]); + int32x4_t _r3 = vld1q_s32(tmp[m][3]); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_r0, _r1), _r2); + int32x4_t _tmp1 = vaddq_s32(vsubq_s32(_r1, _r2), _r3); + + _tmp0 = vshrq_n_s32(_tmp0, 2); + _tmp1 = vshrq_n_s32(_tmp1, 2); + + if (out_elempack == 4) + { + vst1q_s32(outptr0, _tmp0); + if (tj * 2 + 1 < outw) vst1q_s32(outptr0 + 4, _tmp1); + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = vgetq_lane_s32(_tmp0, 0); + outptr1[0] = vgetq_lane_s32(_tmp0, 1); + outptr2[0] = vgetq_lane_s32(_tmp0, 2); + outptr3[0] = vgetq_lane_s32(_tmp0, 3); + + if (tj * 2 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp1, 0); + outptr1[1] = 
vgetq_lane_s32(_tmp1, 1); + outptr2[1] = vgetq_lane_s32(_tmp1, 2); + outptr3[1] = vgetq_lane_s32(_tmp1, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[2][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp00 = tmp[m][0][0] + tmp[m][1][0] + tmp[m][2][0]; + int tmp01 = tmp[m][0][1] + tmp[m][1][1] + tmp[m][2][1]; + int tmp10 = tmp[m][1][0] - tmp[m][2][0] + tmp[m][3][0]; + int tmp11 = tmp[m][1][1] - tmp[m][2][1] + tmp[m][3][1]; + + tmp00 = tmp00 >> 2; + tmp01 = tmp01 >> 2; + tmp10 = tmp10 >> 2; + tmp11 = tmp11 >> 2; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[2][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; + + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp0 = tmp[m][0] + tmp[m][1] + tmp[m][2]; + int tmp1 = tmp[m][1] - tmp[m][2] + tmp[m][3]; + + tmp0 = tmp0 >> 2; + tmp1 = tmp1 >> 2; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 2 + 1 < outw) outptr0[1] = tmp1; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd23_int8(Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 2n+2, winograd F(2,3) + int w_tiles = (outw + 1) / 2; + int h_tiles = (outh + 1) / 2; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 16; + + // NCNN_LOGE("conv3x3s1_winograd23_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, 
opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + // #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + bottom_blob.release(); + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); + } + + // transform output + conv3x3s1_winograd23_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} + +static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const short ktm[6][3] = { + // {6, 0, 0}, + // {-4, -4, -4}, + // {-4, 4, -4}, + // {1, 2, 4}, + // {1, -2, 4}, + // {0, 0, 6} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[6][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 6; + tmp[1][m] = -r0 * 4 - r1 * 4 - r2 * 4; + tmp[2][m] = -r0 * 4 + r1 * 4 - r2 * 4; + tmp[3][m] = r0 + r1 * 2 + r2 * 4; + tmp[4][m] = r0 - r1 * 2 + r2 * 4; + tmp[5][m] = r2 * 6; + + k0 += 3; + } + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 6; + short z1 = -r0 * 4 - r1 * 4 - r2 * 4; + short z2 = -r0 * 4 + r1 * 4 - r2 * 4; + short z3 = r0 + r1 * 2 + r2 * 4; + short z4 = r0 - r1 * 2 + r2 * 4; + short z5 = r2 * 6; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp[4] = z4; + ptmp[5] = z5; + ptmp += 6; + } + } + } +} + +static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ + const 
int M = outch; + const int K = inch; + const int B = 36; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd43_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd43_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const float itm[4][4] = { + // {4, 0, -5, 0, 1, 0}, + // {0, -4, -4, 1, 1, 0}, + // {0, 4, -4, -1, 1, 0}, + // {0, -2, -1, 2, 1, 0}, + // {0, 2, -1, -2, 1, 0}, + // {0, 4, 0, -5, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w + 1) / 4; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __ARM_NEON + nn_max_kk = max_kk / 8; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + + short tmp[6][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + int8x8_t _v5 = vdup_n_s8(5); + + for (int m = 0; m < 6; m++) + { + int8x8_t _r0 = vdup_n_s8(0); + int8x8_t _r1 = vdup_n_s8(0); + int8x8_t _r2 = vdup_n_s8(0); + int8x8_t _r3 = vdup_n_s8(0); + int8x8_t _r4 = vdup_n_s8(0); + int8x8_t _r5 = vdup_n_s8(0); + + if (ti * 4 + m < h) + { + if (elempack == 8) + { + _r0 = vld1_s8(r0); + if (tj * 4 + 1 < w) _r1 = vld1_s8(r0 + 8); + if (tj * 4 + 2 < w) _r2 = vld1_s8(r0 + 16); + if (tj * 4 + 3 < w) _r3 = vld1_s8(r0 + 24); + if (tj * 4 + 4 < w) _r4 = vld1_s8(r0 + 32); + if (tj * 4 + 5 < w) _r5 = vld1_s8(r0 + 40); + } + if (elempack == 1) + { + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + int8x8_t _t0 = vld1_s8(r0); + int8x8_t _t1 = vld1_s8(r1); + int8x8_t _t2 = vld1_s8(r2); + int8x8_t _t3 = vld1_s8(r3); + int8x8_t _t4 = vld1_s8(r4); + int8x8_t _t5 = vld1_s8(r5); + int8x8_t _t6 = vld1_s8(r6); + int8x8_t _t7 = vld1_s8(r7); + + int8x8_t _t01 = vzip_s8(_t0, _t1).val[0]; + int8x8_t _t23 = vzip_s8(_t2, _t3).val[0]; + int8x8_t _t45 = vzip_s8(_t4, _t5).val[0]; + int8x8_t _t67 = vzip_s8(_t6, _t7).val[0]; + int16x4x2_t _t0123 = vzip_s16(vreinterpret_s16_s8(_t01), vreinterpret_s16_s8(_t23)); + int16x4x2_t _t4567 = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)); + int16x8_t _ta = vcombine_s16(_t0123.val[0], _t0123.val[1]); + int16x8_t _tb = vcombine_s16(_t4567.val[0], _t4567.val[1]); + int32x4x2_t 
_tab = vzipq_s32(vreinterpretq_s32_s16(_ta), vreinterpretq_s32_s16(_tb)); + + _r0 = vreinterpret_s8_s32(vget_low_s32(_tab.val[0])); + if (tj * 4 + 1 < w) _r1 = vreinterpret_s8_s32(vget_high_s32(_tab.val[0])); + if (tj * 4 + 2 < w) _r2 = vreinterpret_s8_s32(vget_low_s32(_tab.val[1])); + if (tj * 4 + 3 < w) _r3 = vreinterpret_s8_s32(vget_high_s32(_tab.val[1])); + if (tj * 4 + 4 < w) + { + _t01 = vzip_s8(_t0, _t1).val[1]; + _t23 = vzip_s8(_t2, _t3).val[1]; + _t45 = vzip_s8(_t4, _t5).val[1]; + _t67 = vzip_s8(_t6, _t7).val[1]; + int16x4_t _tc = vzip_s16(vreinterpret_s16_s8(_t01), vreinterpret_s16_s8(_t23)).val[0]; + int16x4_t _td = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)).val[0]; + int32x2x2_t _tcd = vzip_s32(vreinterpret_s32_s16(_tc), vreinterpret_s32_s16(_td)); + + _r4 = vreinterpret_s8_s32(_tcd.val[0]); + if (tj * 4 + 5 < w) _r5 = vreinterpret_s8_s32(_tcd.val[1]); + } + } + } + + int16x8_t _tmp12a = vsubw_s8(vshll_n_s8(_r1, 2), _r3); + int16x8_t _tmp12b = vsubw_s8(vshll_n_s8(_r2, 2), _r4); + int16x8_t _tmp34a = vshlq_n_s16(vsubl_s8(_r3, _r1), 1); + int16x8_t _tmp34b = vsubl_s8(_r4, _r2); + + int16x8_t _tmp0 = vaddq_s16(vmovl_s8(_r4), vsubq_s16(vshll_n_s8(_r0, 2), vmull_s8(_r2, _v5))); + int16x8_t _tmp1 = vnegq_s16(vaddq_s16(_tmp12a, _tmp12b)); + int16x8_t _tmp2 = vsubq_s16(_tmp12a, _tmp12b); + int16x8_t _tmp3 = vaddq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp4 = vsubq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp5 = vaddq_s16(vmovl_s8(_r5), vsubq_s16(vshll_n_s8(_r1, 2), vmull_s8(_r3, _v5))); + + vst1q_s16(tmp[0][m], _tmp0); + vst1q_s16(tmp[1][m], _tmp1); + vst1q_s16(tmp[2][m], _tmp2); + vst1q_s16(tmp[3][m], _tmp3); + vst1q_s16(tmp[4][m], _tmp4); + vst1q_s16(tmp[5][m], _tmp5); + + r0 += w * elempack; + } + + int16x8_t _v5q = vdupq_n_s16(5); + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + short* p4 = p0 + max_jj * 8 * 4; + short* p5 = p0 + max_jj * 8 * 5; + + for (int m = 0; m < 6; m++) + { + int16x8_t _r0 = vld1q_s16(tmp[m][0]); + int16x8_t _r1 = vld1q_s16(tmp[m][1]); + int16x8_t _r2 = vld1q_s16(tmp[m][2]); + int16x8_t _r3 = vld1q_s16(tmp[m][3]); + int16x8_t _r4 = vld1q_s16(tmp[m][4]); + int16x8_t _r5 = vld1q_s16(tmp[m][5]); + + int16x8_t _tmp12a = vsubq_s16(_r3, vshlq_n_s16(_r1, 2)); + int16x8_t _tmp12b = vsubq_s16(_r4, vshlq_n_s16(_r2, 2)); + int16x8_t _tmp34a = vshlq_n_s16(vsubq_s16(_r3, _r1), 1); + int16x8_t _tmp34b = vsubq_s16(_r4, _r2); + + int16x8_t _tmp0 = vaddq_s16(_r4, vsubq_s16(vshlq_n_s16(_r0, 2), vmulq_s16(_r2, _v5q))); + int16x8_t _tmp1 = vaddq_s16(_tmp12b, _tmp12a); + int16x8_t _tmp2 = vsubq_s16(_tmp12b, _tmp12a); + int16x8_t _tmp3 = vaddq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp4 = vsubq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp5 = vaddq_s16(_r5, vsubq_s16(vshlq_n_s16(_r1, 2), vmulq_s16(_r3, _v5q))); + + vst1q_s16(p0, _tmp0); + vst1q_s16(p1, _tmp1); + vst1q_s16(p2, _tmp2); + vst1q_s16(p3, _tmp3); + vst1q_s16(p4, _tmp4); + vst1q_s16(p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __ARM_NEON + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __ARM_NEON + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[6][6][2]; + + int jj = 0; + 
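        // scalar tail: two input channels per iteration; each 6x6 tile is transformed with
        // the B^T coefficients from the itm comment above, applied along both dimensions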
for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + signed char r40 = 0; + signed char r41 = 0; + signed char r50 = 0; + signed char r51 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 4 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 4 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 4 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + if (tj * 4 + 4 < w) + { + r40 = r0[4]; + r41 = r1[4]; + } + if (tj * 4 + 5 < w) + { + r50 = r0[5]; + r51 = r1[5]; + } + } + } + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + tmp[0][m][0] = r40 + r00 * 4 - r20 * 5; + tmp[0][m][1] = r41 + r01 * 4 - r21 * 5; + tmp[1][m][0] = tmp120b + tmp120a; + tmp[1][m][1] = tmp121b + tmp121a; + tmp[2][m][0] = tmp120b - tmp120a; + tmp[2][m][1] = tmp121b - tmp121a; + tmp[3][m][0] = tmp340b + tmp340a; + tmp[3][m][1] = tmp341b + tmp341a; + tmp[4][m][0] = tmp340b - tmp340a; + tmp[4][m][1] = tmp341b - tmp341a; + tmp[5][m][0] = r50 + r10 * 4 - r30 * 5; + tmp[5][m][1] = r51 + r11 * 4 - r31 * 5; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + short* p4 = p0 + max_jj * 2 * 4; + short* p5 = p0 + max_jj * 2 * 5; + + for (int m = 0; m < 6; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + short r40 = tmp[m][4][0]; + short r41 = tmp[m][4][1]; + short r50 = tmp[m][5][0]; + short r51 = tmp[m][5][1]; + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + p0[0] = r40 + r00 * 4 - r20 * 5; + p0[1] = r41 + r01 * 4 - r21 * 5; + p1[0] = tmp120b + tmp120a; + p1[1] = tmp121b + tmp121a; + p2[0] = tmp120b - tmp120a; + p2[1] = tmp121b - tmp121a; + p3[0] = tmp340b + tmp340a; + p3[1] = tmp341b + tmp341a; + p4[0] = tmp340b - tmp340a; + p4[1] = tmp341b - tmp341a; + p5[0] = r50 + r10 * 4 - r30 * 5; + p5[1] = r51 + r11 * 4 - r31 * 5; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[6][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + signed char r4 = 0; + signed char r5 = 0; + + if (ti * 4 + m < h) + { + 
// if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 4 + 1 < w) r1 = r0123[1]; + if (tj * 4 + 2 < w) r2 = r0123[2]; + if (tj * 4 + 3 < w) r3 = r0123[3]; + if (tj * 4 + 4 < w) r4 = r0123[4]; + if (tj * 4 + 5 < w) r5 = r0123[5]; + } + } + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + tmp[0][m] = r4 + r0 * 4 - r2 * 5; + tmp[1][m] = tmp12b + tmp12a; + tmp[2][m] = tmp12b - tmp12a; + tmp[3][m] = tmp34b + tmp34a; + tmp[4][m] = tmp34b - tmp34a; + tmp[5][m] = r5 + r1 * 4 - r3 * 5; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + short* p4 = p0 + max_jj * 4; + short* p5 = p0 + max_jj * 5; + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + short r4 = tmp[m][4]; + short r5 = tmp[m][5]; + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + p0[0] = r4 + r0 * 4 - r2 * 5; + p1[0] = tmp12b + tmp12a; + p2[0] = tmp12b - tmp12a; + p3[0] = tmp34b + tmp34a; + p4[0] = tmp34b - tmp34a; + p5[0] = r5 + r1 * 4 - r3 * 5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; + } + } + } +} + +static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[4][6] = { + // {1, 1, 1, 1, 1, 0}, + // {0, 1, -1, 2, -2, 0}, + // {0, 1, 1, 4, 4, 0}, + // {0, 1, -1, 8, -8, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 3) / 4; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + int tmp[4][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + const int* r4 = r0 + max_jj * 8 * 4; + const int* r5 = r0 + max_jj * 8 * 5; + + for (int m = 0; m < 5; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + int32x4_t _r40 = vld1q_s32(r4); + int32x4_t _r41 = vld1q_s32(r4 + 4); + int32x4_t _r50 = vld1q_s32(r5); + int32x4_t _r51 = vld1q_s32(r5 + 4); + + int32x4_t _tmp02a0 = vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = 
vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), vshlq_n_s32(_r50, 2)); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), vshlq_n_s32(_r51, 2)); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + vst1q_s32(tmp[2][m], _tmp20); + vst1q_s32(tmp[2][m] + 4, _tmp21); + vst1q_s32(tmp[3][m], _tmp30); + vst1q_s32(tmp[3][m] + 4, _tmp31); + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + for (int m = 5; m < 6; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + int32x4_t _r40 = vld1q_s32(r4); + int32x4_t _r41 = vld1q_s32(r4 + 4); + int32x4_t _r50 = vld1q_s32(r5); + int32x4_t _r51 = vld1q_s32(r5 + 4); + + int32x4_t _tmp02a0 = vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), vshlq_n_s32(_r50, 2)); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), vshlq_n_s32(_r51, 2)); + + _tmp00 = vshlq_n_s32(_tmp00, 2); + _tmp01 = vshlq_n_s32(_tmp01, 2); + _tmp10 = vshlq_n_s32(_tmp10, 2); + _tmp11 = vshlq_n_s32(_tmp11, 2); + _tmp20 = vshlq_n_s32(_tmp20, 2); + _tmp21 = vshlq_n_s32(_tmp21, 2); + _tmp30 = vshlq_n_s32(_tmp30, 2); + _tmp31 = vshlq_n_s32(_tmp31, 2); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + vst1q_s32(tmp[2][m], _tmp20); + vst1q_s32(tmp[2][m] + 4, _tmp21); + vst1q_s32(tmp[3][m], _tmp30); + vst1q_s32(tmp[3][m] + 4, _tmp31); + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int32x4_t _r00 = vld1q_s32(tmp[m][0]); + int32x4_t _r01 = vld1q_s32(tmp[m][0] + 4); + int32x4_t _r10 = vld1q_s32(tmp[m][1]); + int32x4_t _r11 = vld1q_s32(tmp[m][1] + 4); + int32x4_t _r20 = vld1q_s32(tmp[m][2]); + int32x4_t _r21 = vld1q_s32(tmp[m][2] + 4); + int32x4_t _r30 = vld1q_s32(tmp[m][3]); + int32x4_t _r31 = vld1q_s32(tmp[m][3] + 4); + int32x4_t _r40 = vld1q_s32(tmp[m][4]); + int32x4_t _r41 = vld1q_s32(tmp[m][4] + 4); + int32x4_t _r50 = vld1q_s32(tmp[m][5]); + int32x4_t _r51 = vld1q_s32(tmp[m][5] + 4); + + int32x4_t _tmp02a0 = 
vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), _r50); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), _r51); + + // TODO use integer trick for division by 576 + float32x4_t _v576 = vdupq_n_f32(1.0 / 576); + _tmp00 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp00), _v576)); + _tmp01 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp01), _v576)); + _tmp10 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp10), _v576)); + _tmp11 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp11), _v576)); + _tmp20 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp20), _v576)); + _tmp21 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp21), _v576)); + _tmp30 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp30), _v576)); + _tmp31 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp31), _v576)); + + if (out_elempack == 8) + { + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr0 + 4, _tmp01); + if (tj * 4 + 1 < outw) + { + vst1q_s32(outptr0 + 8, _tmp10); + vst1q_s32(outptr0 + 12, _tmp11); + } + if (tj * 4 + 2 < outw) + { + vst1q_s32(outptr0 + 16, _tmp20); + vst1q_s32(outptr0 + 20, _tmp21); + } + if (tj * 4 + 3 < outw) + { + vst1q_s32(outptr0 + 24, _tmp30); + vst1q_s32(outptr0 + 28, _tmp31); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr1, _tmp01); + if (tj * 4 + 1 < outw) + { + vst1q_s32(outptr0 + 4, _tmp10); + vst1q_s32(outptr1 + 4, _tmp11); + } + if (tj * 4 + 2 < outw) + { + vst1q_s32(outptr0 + 8, _tmp20); + vst1q_s32(outptr1 + 8, _tmp21); + } + if (tj * 4 + 3 < outw) + { + vst1q_s32(outptr0 + 12, _tmp30); + vst1q_s32(outptr1 + 12, _tmp31); + } + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = vgetq_lane_s32(_tmp00, 0); + outptr1[0] = vgetq_lane_s32(_tmp00, 1); + outptr2[0] = vgetq_lane_s32(_tmp00, 2); + outptr3[0] = vgetq_lane_s32(_tmp00, 3); + outptr4[0] = vgetq_lane_s32(_tmp01, 0); + outptr5[0] = vgetq_lane_s32(_tmp01, 1); + outptr6[0] = vgetq_lane_s32(_tmp01, 2); + outptr7[0] = vgetq_lane_s32(_tmp01, 3); + if (tj * 4 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp10, 0); + outptr1[1] = vgetq_lane_s32(_tmp10, 1); + outptr2[1] = vgetq_lane_s32(_tmp10, 2); + outptr3[1] = vgetq_lane_s32(_tmp10, 3); + outptr4[1] = vgetq_lane_s32(_tmp11, 0); + outptr5[1] = vgetq_lane_s32(_tmp11, 1); + outptr6[1] = vgetq_lane_s32(_tmp11, 2); + outptr7[1] = vgetq_lane_s32(_tmp11, 3); + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = vgetq_lane_s32(_tmp20, 0); + outptr1[2] = vgetq_lane_s32(_tmp20, 1); + outptr2[2] = vgetq_lane_s32(_tmp20, 2); + outptr3[2] = 
vgetq_lane_s32(_tmp20, 3); + outptr4[2] = vgetq_lane_s32(_tmp21, 0); + outptr5[2] = vgetq_lane_s32(_tmp21, 1); + outptr6[2] = vgetq_lane_s32(_tmp21, 2); + outptr7[2] = vgetq_lane_s32(_tmp21, 3); + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = vgetq_lane_s32(_tmp30, 0); + outptr1[3] = vgetq_lane_s32(_tmp30, 1); + outptr2[3] = vgetq_lane_s32(_tmp30, 2); + outptr3[3] = vgetq_lane_s32(_tmp30, 3); + outptr4[3] = vgetq_lane_s32(_tmp31, 0); + outptr5[3] = vgetq_lane_s32(_tmp31, 1); + outptr6[3] = vgetq_lane_s32(_tmp31, 2); + outptr7[3] = vgetq_lane_s32(_tmp31, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + int tmp[4][6][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + const int* r4 = r0 + max_jj * 4 * 4; + const int* r5 = r0 + max_jj * 4 * 5; + + for (int m = 0; m < 5; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + int32x4_t _r4 = vld1q_s32(r4); + int32x4_t _r5 = vld1q_s32(r5); + + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), vshlq_n_s32(_r5, 2)); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + vst1q_s32(tmp[2][m], _tmp2); + vst1q_s32(tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + for (int m = 5; m < 6; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + int32x4_t _r4 = vld1q_s32(r4); + int32x4_t _r5 = vld1q_s32(r5); + + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), vshlq_n_s32(_r5, 2)); + + _tmp0 = vshlq_n_s32(_tmp0, 2); + _tmp1 = vshlq_n_s32(_tmp1, 2); + _tmp2 = vshlq_n_s32(_tmp2, 2); + _tmp3 = vshlq_n_s32(_tmp3, 2); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + vst1q_s32(tmp[2][m], _tmp2); + vst1q_s32(tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int32x4_t _r0 = vld1q_s32(tmp[m][0]); + int32x4_t _r1 = vld1q_s32(tmp[m][1]); + int32x4_t _r2 = vld1q_s32(tmp[m][2]); + int32x4_t _r3 = vld1q_s32(tmp[m][3]); + int32x4_t _r4 = vld1q_s32(tmp[m][4]); + int32x4_t _r5 = vld1q_s32(tmp[m][5]); 
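                // F(4,3) output transform along the second dimension using the otm rows
                // {1,1,1,1,1,0}, {0,1,-1,2,-2,0}, {0,1,1,4,4,0}, {0,1,-1,8,-8,1};
                // the multiply by 1/576 below undoes the integer scaling applied in the
                // kernel and input transforms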
+ + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + float32x4_t _v576 = vdupq_n_f32(1.0 / 576); + _tmp0 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp0), _v576)); + _tmp1 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp1), _v576)); + _tmp2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp2), _v576)); + _tmp3 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp3), _v576)); + + if (out_elempack == 4) + { + vst1q_s32(outptr0, _tmp0); + if (tj * 4 + 1 < outw) vst1q_s32(outptr0 + 4, _tmp1); + if (tj * 4 + 2 < outw) vst1q_s32(outptr0 + 8, _tmp2); + if (tj * 4 + 3 < outw) vst1q_s32(outptr0 + 12, _tmp3); + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = vgetq_lane_s32(_tmp0, 0); + outptr1[0] = vgetq_lane_s32(_tmp0, 1); + outptr2[0] = vgetq_lane_s32(_tmp0, 2); + outptr3[0] = vgetq_lane_s32(_tmp0, 3); + if (tj * 4 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp1, 0); + outptr1[1] = vgetq_lane_s32(_tmp1, 1); + outptr2[1] = vgetq_lane_s32(_tmp1, 2); + outptr3[1] = vgetq_lane_s32(_tmp1, 3); + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = vgetq_lane_s32(_tmp2, 0); + outptr1[2] = vgetq_lane_s32(_tmp2, 1); + outptr2[2] = vgetq_lane_s32(_tmp2, 2); + outptr3[2] = vgetq_lane_s32(_tmp2, 3); + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = vgetq_lane_s32(_tmp3, 0); + outptr1[3] = vgetq_lane_s32(_tmp3, 1); + outptr2[3] = vgetq_lane_s32(_tmp3, 2); + outptr3[3] = vgetq_lane_s32(_tmp3, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[4][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + const int* r4 = r0 + max_jj * 2 * 4; + const int* r5 = r0 + max_jj * 2 * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + for (int m = 5; m < 6; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int 
tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp00 = tmp00 * 4; + tmp01 = tmp01 * 4; + tmp10 = tmp10 * 4; + tmp11 = tmp11 * 4; + tmp20 = tmp20 * 4; + tmp21 = tmp21 * 4; + tmp30 = tmp30 * 4; + tmp31 = tmp31 * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a0 = tmp[m][1][0] + tmp[m][2][0]; + int tmp02a1 = tmp[m][1][1] + tmp[m][2][1]; + int tmp02b0 = tmp[m][3][0] + tmp[m][4][0]; + int tmp02b1 = tmp[m][3][1] + tmp[m][4][1]; + int tmp13a0 = tmp[m][1][0] - tmp[m][2][0]; + int tmp13a1 = tmp[m][1][1] - tmp[m][2][1]; + int tmp13b0 = tmp[m][3][0] - tmp[m][4][0]; + int tmp13b1 = tmp[m][3][1] - tmp[m][4][1]; + + int tmp00 = tmp02a0 + tmp02b0 + tmp[m][0][0]; + int tmp01 = tmp02a1 + tmp02b1 + tmp[m][0][1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + tmp[m][5][0]; + int tmp31 = tmp13a1 + tmp13b1 * 8 + tmp[m][5][1]; + + tmp00 = tmp00 / 576; + tmp01 = tmp01 / 576; + tmp10 = tmp10 / 576; + tmp11 = tmp11 / 576; + tmp20 = tmp20 / 576; + tmp21 = tmp21 / 576; + tmp30 = tmp30 / 576; + tmp31 = tmp31 / 576; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp20; + outptr1[2] = tmp21; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp30; + outptr1[3] = tmp31; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[4][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + const int* r4 = r0 + max_jj * 4; + const int* r5 = r0 + max_jj * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + for (int m = 5; m < 6; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 
4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp0 = tmp0 * 4; + tmp1 = tmp1 * 4; + tmp2 = tmp2 * 4; + tmp3 = tmp3 * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a = tmp[m][1] + tmp[m][2]; + int tmp02b = tmp[m][3] + tmp[m][4]; + int tmp13a = tmp[m][1] - tmp[m][2]; + int tmp13b = tmp[m][3] - tmp[m][4]; + + int tmp0 = tmp02a + tmp02b + tmp[m][0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + tmp[m][5]; + + tmp0 = tmp0 / 576; + tmp1 = tmp1 / 576; + tmp2 = tmp2 / 576; + tmp3 = tmp3 / 576; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 4 + 1 < outw) outptr0[1] = tmp1; + if (tj * 4 + 2 < outw) outptr0[2] = tmp2; + if (tj * 4 + 3 < outw) outptr0[3] = tmp3; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd43_int8(Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 4n+2, winograd F(4,3) + int w_tiles = (outw + 3) / 4; + int h_tiles = (outh + 3) / 4; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 36; + + // NCNN_LOGE("conv3x3s1_winograd43_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + bottom_blob.release(); + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp 
parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); + } + + // transform output + conv3x3s1_winograd43_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} diff --git a/src/layer/arm/convolution_arm.cpp b/src/layer/arm/convolution_arm.cpp index c8d48aec762..849a8daea6b 100644 --- a/src/layer/arm/convolution_arm.cpp +++ b/src/layer/arm/convolution_arm.cpp @@ -49,10 +49,9 @@ namespace ncnn { #if NCNN_INT8 #include "convolution_im2col_gemm_int8.h" +#include "convolution_3x3_winograd_int8.h" -#include "convolution_winograd_transform_int8.h" -#include "convolution_winograd_dot_int8.h" -#include "convolution_3x3_int8.h" +// #include "convolution_3x3_int8.h" #include "convolution_int8.h" #endif // NCNN_INT8 @@ -74,12 +73,6 @@ namespace ncnn { #include "convolution_pack8to4_int8.h" #include "convolution_pack1to4_int8.h" #include "convolution_pack8to1_int8.h" -#include "convolution_winograd_transform_pack4_int8.h" -#include "convolution_winograd_transform_pack8_int8.h" -#include "convolution_winograd_dot_pack8to4_int8.h" -#include "convolution_winograd_dot_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" -#include "convolution_3x3_pack8to1_int8.h" #endif // NCNN_INT8 #endif // __ARM_NEON @@ -1285,6 +1278,14 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) const int maxk = kernel_w * kernel_h; const int num_input = weight_data_size / maxk / num_output; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input >= 8 && num_output >= 8) && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1; +#if NCNN_ARM82DOT + if (ncnn::cpu_support_arm_asimddp()) + { + prefer_winograd = false; + } +#endif + int elempack = 1; int out_elempack = 1; #if __ARM_NEON @@ -1295,25 +1296,12 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) } #endif // __ARM_NEON -#if NCNN_ARM82DOT - if (elempack == 8 && out_elempack == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && (!ncnn::cpu_support_arm_asimddp() || (ncnn::cpu_support_arm_asimddp() && num_input >= 256 && num_output >= 256))) -#else - if (elempack == 8 && out_elempack == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) -#endif - { -#if __ARM_NEON - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __ARM_NEON - } - else if (elempack == 8 && out_elempack == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + if (opt.use_winograd_convolution && prefer_winograd) { -#if __ARM_NEON - 
conv3x3s1_winograd43_transform_kernel_pack8to1_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __ARM_NEON - } - else if (elempack == 1 && out_elempack == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_winograd43_transform_kernel_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); + if (opt.use_winograd43_convolution) + conv3x3s1_winograd43_transform_kernel_int8(weight_data, weight_winograd43_data, num_input, num_output, opt); + else + conv3x3s1_winograd23_transform_kernel_int8(weight_data, weight_winograd23_data, num_input, num_output, opt); } else if (opt.use_sgemm_convolution) { @@ -1321,10 +1309,6 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) } else if (elempack == 1 && out_elempack == 1) { - // if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) - // { - // conv3x3s2_transform_kernel_int8_neon(weight_data, weight_3x3s2_data_int8, num_input, num_output); - // } weight_data_tm = weight_data; } else @@ -1405,20 +1389,29 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con // NCNN_LOGE("forward_int8_arm %d %d %d %d %d", w, h, bottom_blob_bordered.c, elempack, out_elempack); - top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); - if (top_blob.empty()) - return -100; - -#if NCNN_ARM82DOT int channels = bottom_blob_bordered.c; const int num_input = channels * elempack; + + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input >= 8 && num_output >= 8) && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1; +#if NCNN_ARM82DOT + if (ncnn::cpu_support_arm_asimddp()) + { + prefer_winograd = false; + } #endif int out_elempack_int32 = 1; #if __ARM_NEON if (opt.use_packing_layout) { - out_elempack_int32 = num_output % 4 == 0 ? 4 : 1; + if ((opt.use_winograd_convolution && prefer_winograd) || opt.use_sgemm_convolution) + { + out_elempack_int32 = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + } + else + { + out_elempack_int32 = num_output % 4 == 0 ? 
4 : 1; + } } #endif // __ARM_NEON @@ -1435,25 +1428,12 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } -#if NCNN_ARM82DOT - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && (!ncnn::cpu_support_arm_asimddp() || (ncnn::cpu_support_arm_asimddp() && num_input >= 256 && num_output >= 256))) -#else - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) -#endif + if (opt.use_winograd_convolution && prefer_winograd) { -#if __ARM_NEON - conv3x3s1_winograd43_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __ARM_NEON - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __ARM_NEON - conv3x3s1_winograd43_pack8to1_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __ARM_NEON - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_winograd43_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); + if (opt.use_winograd43_convolution && !weight_winograd43_data.empty()) + conv3x3s1_winograd43_int8(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, _nT, opt); + else + conv3x3s1_winograd23_int8(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, _nT, opt); } else if (opt.use_sgemm_convolution) { @@ -1478,6 +1458,12 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } + bottom_blob_bordered.release(); + + top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + if (use_int8_requantize) { requantize_from_int32_to_int8(top_blob_int32, top_blob, scale_in_data, top_blob_int8_scales, bias_data, activation_type, activation_params, opt); diff --git a/src/layer/arm/convolution_im2col_gemm_int8.h b/src/layer/arm/convolution_im2col_gemm_int8.h index 1fecc86448b..63a1df9c32b 100644 --- a/src/layer/arm/convolution_im2col_gemm_int8.h +++ b/src/layer/arm/convolution_im2col_gemm_int8.h @@ -654,7 +654,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M // NCNN_LOGE("convolution_gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); const int out_elempack = top_blob.elempack; - const int out_hstep = (int)top_blob.cstep; + const size_t out_hstep = top_blob.cstep; const signed char* pAT = AT_tile; const signed char* pBT = BT_tile; @@ -1150,7 +1150,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M "beq 8f \n" // if out_elempack == 8 - "cmp %11, #8 \n" + "cmp %w11, #8 \n" "bne 7f \n" "st1 {v16.4s}, [%3], #16 
\n" @@ -1304,7 +1304,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M "zip2 v15.4s, v21.4s, v25.4s \n" // if out_elempack == 8 - "cmp %11, #8 \n" + "cmp %w11, #8 \n" "bne 7f \n" // to @@ -2737,7 +2737,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M "beq 8f \n" // if out_elempack == 8 - "cmp %11, #8 \n" + "cmp %w11, #8 \n" "bne 7f \n" "st1 {v16.4s}, [%3], #16 \n" @@ -2833,7 +2833,7 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M "zip2 v7.4s, v19.4s, v21.4s \n" // if out_elempack == 8 - "cmp %11, #8 \n" + "cmp %w11, #8 \n" "bne 7f \n" // to @@ -7636,7 +7636,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 8) { const signed char* p0 = (const signed char*)bottom_blob.channel(k / 8) + (j + jj) * 8; - const int cstep = bottom_blob.cstep * 8; + const size_t cstep = bottom_blob.cstep * 8; int kk = 0; #if __ARM_FEATURE_MATMUL_INT8 @@ -7725,7 +7725,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 1) { const signed char* p0 = (const signed char*)bottom_blob.channel(k) + (j + jj); - const int cstep = bottom_blob.cstep; + const size_t cstep = bottom_blob.cstep; int kk = 0; #if __ARM_FEATURE_DOTPROD @@ -7855,7 +7855,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 8) { const signed char* p0 = (const signed char*)bottom_blob.channel(k / 8) + (j + jj) * 8; - const int cstep = bottom_blob.cstep * 8; + const size_t cstep = bottom_blob.cstep * 8; int kk = 0; #if __ARM_FEATURE_MATMUL_INT8 @@ -7950,7 +7950,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 1) { const signed char* p0 = (const signed char*)bottom_blob.channel(k) + (j + jj); - const int cstep = bottom_blob.cstep; + const size_t cstep = bottom_blob.cstep; int kk = 0; #if __ARM_FEATURE_DOTPROD @@ -8046,7 +8046,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 8) { const signed char* p0 = (const signed char*)bottom_blob.channel(k / 8) + (j + jj) * 8; - const int cstep = bottom_blob.cstep * 8; + const size_t cstep = bottom_blob.cstep * 8; int kk = 0; #if __ARM_FEATURE_MATMUL_INT8 @@ -8135,7 +8135,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 1) { const signed char* p0 = (const signed char*)bottom_blob.channel(k) + (j + jj); - const int cstep = bottom_blob.cstep; + const size_t cstep = bottom_blob.cstep; int kk = 0; #if __ARM_NEON @@ -8202,7 +8202,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 8) { const signed char* p0 = (const signed char*)bottom_blob.channel(k / 8) + (j + jj) * 8; - const int cstep = bottom_blob.cstep * 8; + const size_t cstep = bottom_blob.cstep * 8; int kk = 0; for (; kk < max_kk / 8; kk++) @@ -8217,7 +8217,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo if (elempack == 1) { const signed char* p0 = (const signed char*)bottom_blob.channel(k) + (j + jj); - const int cstep = bottom_blob.cstep; + const size_t cstep = bottom_blob.cstep; int kk = 0; for (; kk < max_kk; kk++) diff --git a/src/layer/arm/convolution_winograd_dot_int8.h b/src/layer/arm/convolution_winograd_dot_int8.h deleted file mode 100644 index d5cf1bcd87e..00000000000 --- a/src/layer/arm/convolution_winograd_dot_int8.h +++ /dev/null @@ -1,1005 +0,0 @@ -// 
Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -static void convolution_winograd_dot_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 2u, 1, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __ARM_NEON -#if __aarch64__ - if (tiles >= 8) - bottom_blob_tm2.create(inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(inch, tiles / 4 + tiles % 4, batch, 8u, 4, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(inch, tiles / 4 + tiles % 4, batch, 8u, 4, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#endif -#else // __ARM_NEON - if (tiles >= 2) - bottom_blob_tm2.create(inch, tiles / 2 + tiles % 2, batch, 4u, 2, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#endif // __ARM_NEON - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __ARM_NEON -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - short* tmpptr = tm2.row(i / 8); - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - vst1q_s16(tmpptr, _r0); - r0 += bottom_blob_tm.cstep; - tmpptr += 8; - } - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 8 + (i % 8) / 4); -#else - short* tmpptr = tm2.row(i / 4); -#endif - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - int16x4_t _r0 = vld1_s16(r0); - vst1_s16(tmpptr, _r0); - r0 += bottom_blob_tm.cstep; - tmpptr += 4; - } - } -#else // __ARM_NEON - for (; i + 1 < tiles; i += 2) - { - short* tmpptr = tm2.row(i / 2); - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; -#if __ARM_FEATURE_SIMD32 - for (; q + 1 < inch; q += 2) - { - tmpptr[0] = r0[0]; - tmpptr[2] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr[1] = r0[0]; - tmpptr[3] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr += 4; - } -#endif // __ARM_FEATURE_SIMD32 - for (; q < inch; q++) - { - tmpptr[0] = r0[0]; - tmpptr[1] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr += 2; - } - } -#endif // __ARM_NEON - for (; i < tiles; i++) - { -#if __ARM_NEON -#if __aarch64__ - short* tmpptr = tm2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - short* tmpptr = 
tm2.row(i / 4 + i % 4); -#endif -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - tmpptr[0] = r0[0]; - r0 += bottom_blob_tm.cstep; - tmpptr += 1; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 4u, 1, opt.workspace_allocator); - -#if __ARM_NEON - int nn_outch = outch >> 3; - int remain_outch_start = nn_outch << 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 8; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - int* output4_tm = top_blob_tm.channel(p + 4); - int* output5_tm = top_blob_tm.channel(p + 5); - int* output6_tm = top_blob_tm.channel(p + 6); - int* output7_tm = top_blob_tm.channel(p + 7); - - const Mat kernel0_tm = kernel_tm.channel(p / 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum00 = vdupq_n_s32(0); - int32x4_t _sum10 = vdupq_n_s32(0); - int32x4_t _sum20 = vdupq_n_s32(0); - int32x4_t _sum30 = vdupq_n_s32(0); - int32x4_t _sum40 = vdupq_n_s32(0); - int32x4_t _sum50 = vdupq_n_s32(0); - int32x4_t _sum60 = vdupq_n_s32(0); - int32x4_t _sum70 = vdupq_n_s32(0); - int32x4_t _sum01 = vdupq_n_s32(0); - int32x4_t _sum11 = vdupq_n_s32(0); - int32x4_t _sum21 = vdupq_n_s32(0); - int32x4_t _sum31 = vdupq_n_s32(0); - int32x4_t _sum41 = vdupq_n_s32(0); - int32x4_t _sum51 = vdupq_n_s32(0); - int32x4_t _sum61 = vdupq_n_s32(0); - int32x4_t _sum71 = vdupq_n_s32(0); - - int j = 0; - for (; j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val0), vget_low_s16(_w0), 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val0), vget_low_s16(_w0), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val0), vget_low_s16(_w0), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val0), vget_low_s16(_w0), 3); - _sum40 = vmlal_lane_s16(_sum40, vget_low_s16(_val0), vget_high_s16(_w0), 0); - _sum50 = vmlal_lane_s16(_sum50, vget_low_s16(_val0), vget_high_s16(_w0), 1); - _sum60 = vmlal_lane_s16(_sum60, vget_low_s16(_val0), vget_high_s16(_w0), 2); - _sum70 = vmlal_lane_s16(_sum70, vget_low_s16(_val0), vget_high_s16(_w0), 3); - - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val0), vget_low_s16(_w0), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val0), vget_low_s16(_w0), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val0), vget_low_s16(_w0), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val0), vget_low_s16(_w0), 3); - _sum41 = vmlal_lane_s16(_sum41, vget_high_s16(_val0), vget_high_s16(_w0), 0); - _sum51 = vmlal_lane_s16(_sum51, vget_high_s16(_val0), vget_high_s16(_w0), 1); - _sum61 = vmlal_lane_s16(_sum61, vget_high_s16(_val0), vget_high_s16(_w0), 2); - _sum71 = vmlal_lane_s16(_sum71, vget_high_s16(_val0), vget_high_s16(_w0), 3); - - r0 += 8; - k0 += 8; - } - - vst1q_s32(output0_tm, _sum00); - vst1q_s32(output0_tm + 4, _sum01); - vst1q_s32(output1_tm, _sum10); - vst1q_s32(output1_tm + 4, _sum11); - vst1q_s32(output2_tm, _sum20); - vst1q_s32(output2_tm + 4, _sum21); - vst1q_s32(output3_tm, _sum30); - vst1q_s32(output3_tm + 4, _sum31); 
- vst1q_s32(output4_tm, _sum40); - vst1q_s32(output4_tm + 4, _sum41); - vst1q_s32(output5_tm, _sum50); - vst1q_s32(output5_tm + 4, _sum51); - vst1q_s32(output6_tm, _sum60); - vst1q_s32(output6_tm + 4, _sum61); - vst1q_s32(output7_tm, _sum70); - vst1q_s32(output7_tm + 4, _sum71); - - output0_tm += 8; - output1_tm += 8; - output2_tm += 8; - output3_tm += 8; - output4_tm += 8; - output5_tm += 8; - output6_tm += 8; - output7_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val01), vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val01), vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val01), vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val01), vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val01), vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val01), vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val01), vget_high_s16(_w0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val01), vget_low_s16(_w1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), vget_low_s16(_w1), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val01), vget_low_s16(_w1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val01), vget_low_s16(_w1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val01), vget_high_s16(_w1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val01), vget_high_s16(_w1), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val01), vget_high_s16(_w1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val01), vget_high_s16(_w1), 3); - - r0 += 8; - k0 += 16; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, _val0, vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, _val0, vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, _val0, vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, _val0, vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, _val0, vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, _val0, vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, _val0, vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, _val0, vget_high_s16(_w0), 3); - - r0 += 4; - k0 += 8; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - vst1q_s32(output4_tm, _sum4); - vst1q_s32(output5_tm, _sum5); - vst1q_s32(output6_tm, _sum6); - vst1q_s32(output7_tm, _sum7); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - output4_tm += 4; - output5_tm += 4; - output6_tm += 4; - output7_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = 
bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x4_t _val0123 = vld1_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), _val0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), _val0123, 0); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), _val0123, 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), _val0123, 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), _val0123, 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), _val0123, 2); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), _val0123, 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), _val0123, 3); - - r0 += 4; - k0 += 32; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_dup_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w0)); - _sum1 = vmlal_s16(_sum1, _val0, vget_high_s16(_w0)); - - r0 += 1; - k0 += 8; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output4_tm[0] = vgetq_lane_s32(_sum1, 0); - output5_tm[0] = vgetq_lane_s32(_sum1, 1); - output6_tm[0] = vgetq_lane_s32(_sum1, 2); - output7_tm[0] = vgetq_lane_s32(_sum1, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - output4_tm += 1; - output5_tm += 1; - output6_tm += 1; - output7_tm += 1; - } - } - } - - nn_outch = (outch - remain_outch_start) >> 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = remain_outch_start + pp * 4; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p / 8 + (p % 8) / 4); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum00 = vdupq_n_s32(0); - int32x4_t _sum10 = vdupq_n_s32(0); - int32x4_t _sum20 = vdupq_n_s32(0); - int32x4_t _sum30 = vdupq_n_s32(0); - int32x4_t _sum01 = vdupq_n_s32(0); - int32x4_t _sum11 = vdupq_n_s32(0); - int32x4_t _sum21 = vdupq_n_s32(0); - int32x4_t _sum31 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x8_t _w01 = vld1q_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val01), vget_low_s16(_w01), 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val01), vget_low_s16(_w01), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val01), vget_low_s16(_w01), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val01), vget_low_s16(_w01), 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val01), vget_low_s16(_w01), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val01), vget_low_s16(_w01), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val01), vget_low_s16(_w01), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val01), vget_low_s16(_w01), 3); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val23), vget_high_s16(_w01), 0); - 
_sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val23), vget_high_s16(_w01), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val23), vget_high_s16(_w01), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val23), vget_high_s16(_w01), 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val23), vget_high_s16(_w01), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val23), vget_high_s16(_w01), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val23), vget_high_s16(_w01), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val23), vget_high_s16(_w01), 3); - - r0 += 16; - k0 += 8; - } - for (; j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val0), _w0, 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val0), _w0, 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val0), _w0, 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val0), _w0, 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val0), _w0, 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val0), _w0, 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val0), _w0, 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val0), _w0, 3); - - r0 += 8; - k0 += 4; - } - - vst1q_s32(output0_tm, _sum00); - vst1q_s32(output0_tm + 4, _sum01); - vst1q_s32(output1_tm, _sum10); - vst1q_s32(output1_tm + 4, _sum11); - vst1q_s32(output2_tm, _sum20); - vst1q_s32(output2_tm + 4, _sum21); - vst1q_s32(output3_tm, _sum30); - vst1q_s32(output3_tm + 4, _sum31); - - output0_tm += 8; - output1_tm += 8; - output2_tm += 8; - output3_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _w01 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), vget_low_s16(_w01), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val01), vget_low_s16(_w01), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val01), vget_low_s16(_w01), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val01), vget_low_s16(_w01), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val01), vget_high_s16(_w01), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), vget_high_s16(_w01), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val01), vget_high_s16(_w01), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val01), vget_high_s16(_w01), 3); - - r0 += 8; - k0 += 8; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, _val0, _w0, 0); - _sum1 = vmlal_lane_s16(_sum1, _val0, _w0, 1); - _sum2 = vmlal_lane_s16(_sum2, _val0, _w0, 2); - _sum3 = vmlal_lane_s16(_sum3, _val0, _w0, 3); - - r0 += 4; - k0 += 4; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - 
int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x4_t _val0123 = vld1_s16(r0); - int16x8_t _w01 = vld1q_s16(k0); - int16x8_t _w23 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w01), _val0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w01), _val0123, 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w23), _val0123, 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w23), _val0123, 3); - - r0 += 4; - k0 += 16; - } - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum0 = vaddq_s32(_sum0, _sum2); - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_dup_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, _w0); - - r0 += 1; - k0 += 4; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 2; -#else // __ARM_NEON - int nn_outch = outch >> 1; - int remain_outch_start = nn_outch << 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 2; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - - const Mat kernel0_tm = kernel_tm.channel(p / 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; - for (; i + 1 < tiles; i += 2) - { - const short* r0 = bb2.row(i / 2); - const short* k0 = kernel0_tm.row(r); - - int sum00 = 0; - int sum10 = 0; - int sum01 = 0; - int sum11 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - // fomit-frame-pointer implied in optimized flag spare one register - // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui -#if __OPTIMIZE__ - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val02 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%0], #4 \n" // int16x2_t _val13 = *((int16x2_t*)r0); r0 += 2; - "ldr r4, [%1], #4 \n" // int16x2_t _w02 = *((int16x2_t*)k0); k0 += 2; - "ldr r5, [%1], #4 \n" // int16x2_t _w13 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_val02, _w02, sum00); - "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_val13, _w02, sum01); - "smlad %4, r2, r5, %4 \n" // sum10 = __smlad(_val02, _w13, sum10); - "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_val13, _w13, sum11); - : "=r"(r0), - "=r"(k0), - "=r"(sum00), - "=r"(sum01), - "=r"(sum10), - "=r"(sum11) - : "0"(r0), - "1"(k0), - "2"(sum00), - "3"(sum01), - "4"(sum10), - "5"(sum11) - : "memory", "r2", "r3", "r4", "r5"); -#else - int _val02 = *((int*)r0); - int _val13 = *((int*)(r0 + 2)); - int _w02 = *((int*)k0); - int _w13 = *((int*)(k0 + 2)); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum00) - : "0"(sum00), "r"(_val02), "r"(_w02) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum01) - : "0"(sum01), "r"(_val13), "r"(_w02) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum10) - : "0"(sum10), "r"(_val02), "r"(_w13) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum11) - : "0"(sum11), "r"(_val13), "r"(_w13) - :); - r0 += 4; - k0 += 4; -#endif - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val0 = r0[0]; - signed short val1 = r0[1]; - - signed short w0 = k0[0]; - signed short w1 = k0[1]; - - sum00 += val0 * w0; - sum10 += val0 
* w1; - sum01 += val1 * w0; - sum11 += val1 * w1; - - r0 += 2; - k0 += 2; - } - - output0_tm[0] = sum00; - output1_tm[0] = sum10; - output0_tm[1] = sum01; - output1_tm[1] = sum11; - output0_tm += 2; - output1_tm += 2; - } - for (; i < tiles; i++) - { - const short* r0 = bb2.row(i / 2 + i % 2); - const short* k0 = kernel0_tm.row(r); - - int sum0 = 0; - int sum1 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val01 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%1], #4 \n" // int16x2_t _w02 = *((int16x2_t*)k0); k0 += 2; - "ldr r4, [%1], #4 \n" // int16x2_t _w13 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r3, %2 \n" // sum00 = __smlad(_val01, _w02, sum00); - "smlad %3, r2, r4, %3 \n" // sum01 = __smlad(_val01, _w02, sum01); - : "=r"(r0), - "=r"(k0), - "=r"(sum0), - "=r"(sum1) - : "0"(r0), - "1"(k0), - "2"(sum0), - "3"(sum1) - : "memory", "r2", "r3", "r4"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val = r0[0]; - - sum0 += val * k0[0]; - sum1 += val * k0[1]; - - r0 += 1; - k0 += 2; - } - - output0_tm[0] = sum0; - output1_tm[0] = sum1; - output0_tm += 1; - output1_tm += 1; - } - } - } -#endif // __ARM_NEON - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - -#if __ARM_NEON - const Mat kernel0_tm = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4); -#else - const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2); -#endif - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __ARM_NEON -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x8_t _val45 = vld1q_s16(r0 + 16); - int16x8_t _val67 = vld1q_s16(r0 + 24); - int16x4_t _w0123 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), _w0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), _w0123, 0); - - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val23), _w0123, 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val23), _w0123, 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val45), _w0123, 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val45), _w0123, 2); - - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val67), _w0123, 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val67), _w0123, 3); - - k0 += 4; - r0 += 32; - } - _sum0 = vaddq_s32(_sum0, _sum2); - _sum1 = vaddq_s32(_sum1, _sum3); - for (; j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x4_t _w0 = vld1_dup_s16(k0); - - _sum0 = vmlal_s16(_sum0, _w0, vget_low_s16(_val0)); - _sum1 = vmlal_s16(_sum1, _w0, vget_high_s16(_val0)); - - k0 += 1; - r0 += 8; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - output0_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j 
+ 3 < inch; j += 4) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x4_t _w0123 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), _w0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), _w0123, 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val23), _w0123, 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val23), _w0123, 3); - - k0 += 4; - r0 += 16; - } - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum0 = vaddq_s32(_sum0, _sum2); - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - int16x4_t _w0 = vld1_dup_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, _w0); - - k0 += 1; - r0 += 4; - } - - vst1q_s32(output0_tm, _sum0); - output0_tm += 4; - } -#else - for (; i + 1 < tiles; i += 2) - { - const short* r0 = bb2.row(i / 2); - const short* k0 = kernel0_tm.row(r); - - int sum0 = 0; - int sum1 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val02 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%0], #4 \n" // int16x2_t _val13 = *((int16x2_t*)r0); r0 += 2; - "ldr r4, [%1], #4 \n" // int16x2_t _w01 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_val02, _w01, sum00); - "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_val13, _w01, sum01); - : "=r"(r0), - "=r"(k0), - "=r"(sum0), - "=r"(sum1) - : "0"(r0), - "1"(k0), - "2"(sum0), - "3"(sum1) - : "memory", "r2", "r3", "r4"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val0 = r0[0]; - signed short val1 = r0[1]; - signed short w = k0[0]; - - sum0 += val0 * w; - sum1 += val1 * w; - - k0 += 1; - r0 += 2; - } - - output0_tm[0] = sum0; - output0_tm[1] = sum1; - output0_tm += 2; - } -#endif - for (; i < tiles; i++) - { -#if __ARM_NEON -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int sum = 0; - - int j = 0; -#if __ARM_NEON - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - for (; j + 7 < inch; j += 8) - { - int16x8_t _val = vld1q_s16(r0); - int16x8_t _w = vld1q_s16(k0); - - _sum0 = vmlal_s16(_sum0, vget_low_s16(_val), vget_low_s16(_w)); - _sum1 = vmlal_s16(_sum1, vget_high_s16(_val), vget_high_s16(_w)); - - k0 += 8; - r0 += 8; - } - _sum0 = vaddq_s32(_sum0, _sum1); - for (; j + 3 < inch; j += 4) - { - int16x4_t _val = vld1_s16(r0); - int16x4_t _w = vld1_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val, _w); - - k0 += 4; - r0 += 4; - } -#if __aarch64__ - sum = vaddvq_s32(_sum0); -#else - int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); - _ss = vpadd_s32(_ss, _ss); - - sum = vget_lane_s32(_ss, 0); -#endif -#endif // __ARM_NEON -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%1], #4 \n" // int16x2_t _w = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r3, %2 \n" // sum = __smlad(_val, _w, sum); - : "=r"(r0), - "=r"(k0), - "=r"(sum) - : "0"(r0), - "1"(k0), - "2"(sum) - : "memory", "r2", "r3"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val = r0[0]; - signed short w = k0[0]; - - sum += val * w; - - k0 += 1; - r0 += 1; - } - - output0_tm[0] = sum; - output0_tm++; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h 
b/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h deleted file mode 100644 index 6192be12846..00000000000 --- a/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h +++ /dev/null @@ -1,774 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -static void convolution_winograd_dot_pack8to1_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 16u, 8, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __aarch64__ - if (tiles >= 8) - bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#endif // __aarch64__ - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - short* tm2p = tm2.row(i / 8); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" - "sub %0, %0, #64 \n" - - "uzp1 v16.8h, v0.8h, v4.8h \n" - "uzp2 v20.8h, v0.8h, v4.8h \n" - "uzp1 v17.8h, v1.8h, v5.8h \n" - "uzp2 v21.8h, v1.8h, v5.8h \n" - "uzp1 v18.8h, v2.8h, v6.8h \n" - "uzp2 v22.8h, v2.8h, v6.8h \n" - "uzp1 v19.8h, v3.8h, v7.8h \n" - "uzp2 v23.8h, v3.8h, v7.8h \n" - - "st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); - - r0 += bottom_blob_tm.cstep * 8; - } - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tm2p = tm2.row(i / 8 + (i % 8) / 4); -#else - short* tm2p = tm2.row(i / 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x4 -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" - "st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - : 
"=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3"); -#else - asm volatile( - "pld [%0, #512] \n" - "vldm %0, {d0-d7} \n" - "vswp d1, d2 \n" - "vswp d5, d6 \n" - "vswp q1, q2 \n" - "vst4.s16 {d0-d3}, [%1 :64]! \n" - "vst4.s16 {d4-d7}, [%1 :64]! \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "q0", "q1", "q2", "q3"); -#endif // __aarch64__ - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i < tiles; i++) - { -#if __aarch64__ - short* tm2p = tm2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - short* tm2p = tm2.row(i / 4 + i % 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.8h}, [%0] \n" - "st1 {v0.8h}, [%1], #16 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0"); -#else - asm volatile( - "pld [%0, #128] \n" - "vld1.s16 {d0-d1}, [%0 :64] \n" - "vst1.s16 {d0-d1}, [%1 :64]! \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "q0"); -#endif // __aarch64__ - r0 += bottom_blob_tm.cstep * 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 4u, 1, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 8; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - int* output4_tm = top_blob_tm.channel(p + 4); - int* output5_tm = top_blob_tm.channel(p + 5); - int* output6_tm = top_blob_tm.channel(p + 6); - int* output7_tm = top_blob_tm.channel(p + 7); - - const Mat kernel01_tm = kernel_tm.channel(p / 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* kptr = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - asm volatile( - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - "eor v20.16b, v20.16b, v20.16b \n" - "eor v21.16b, v21.16b, v21.16b \n" - "eor v22.16b, v22.16b, v22.16b \n" - "eor v23.16b, v23.16b, v23.16b \n" - "eor v24.16b, v24.16b, v24.16b \n" - "eor v25.16b, v25.16b, v25.16b \n" - "eor v26.16b, v26.16b, v26.16b \n" - "eor v27.16b, v27.16b, v27.16b \n" - "eor v28.16b, v28.16b, v28.16b \n" - "eor v29.16b, v29.16b, v29.16b \n" - "eor v30.16b, v30.16b, v30.16b \n" - "eor v31.16b, v31.16b, v31.16b \n" - - "0: \n" - - "prfm pldl1keep, [%9, #512] \n" - "ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [%9], #64 \n" - - "prfm pldl1keep, [%10, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%10], #64 \n" - - "smlal v16.4s, v8.4h, v0.h[0] \n" - "smlal2 v17.4s, v8.8h, v0.h[0] \n" - "smlal v18.4s, v8.4h, v0.h[1] \n" - "smlal2 v19.4s, v8.8h, v0.h[1] \n" - "smlal v20.4s, v8.4h, v0.h[2] \n" - "smlal2 v21.4s, v8.8h, v0.h[2] \n" - "smlal v22.4s, v8.4h, v0.h[3] \n" - "smlal2 v23.4s, v8.8h, v0.h[3] \n" - "smlal v24.4s, v8.4h, v0.h[4] \n" - "smlal2 v25.4s, v8.8h, v0.h[4] \n" - "smlal v26.4s, v8.4h, v0.h[5] \n" - "smlal2 v27.4s, v8.8h, v0.h[5] \n" - "smlal v28.4s, v8.4h, v0.h[6] \n" - "smlal2 v29.4s, v8.8h, v0.h[6] \n" - "smlal 
v30.4s, v8.4h, v0.h[7] \n" - "smlal2 v31.4s, v8.8h, v0.h[7] \n" - - "smlal v16.4s, v9.4h, v1.h[0] \n" - "smlal2 v17.4s, v9.8h, v1.h[0] \n" - "smlal v18.4s, v9.4h, v1.h[1] \n" - "smlal2 v19.4s, v9.8h, v1.h[1] \n" - "smlal v20.4s, v9.4h, v1.h[2] \n" - "smlal2 v21.4s, v9.8h, v1.h[2] \n" - "smlal v22.4s, v9.4h, v1.h[3] \n" - "smlal2 v23.4s, v9.8h, v1.h[3] \n" - "smlal v24.4s, v9.4h, v1.h[4] \n" - "smlal2 v25.4s, v9.8h, v1.h[4] \n" - "smlal v26.4s, v9.4h, v1.h[5] \n" - "smlal2 v27.4s, v9.8h, v1.h[5] \n" - "smlal v28.4s, v9.4h, v1.h[6] \n" - "smlal2 v29.4s, v9.8h, v1.h[6] \n" - "smlal v30.4s, v9.4h, v1.h[7] \n" - "smlal2 v31.4s, v9.8h, v1.h[7] \n" - - "prfm pldl1keep, [%9, #512] \n" - "ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [%9], #64 \n" - - "smlal v16.4s, v10.4h, v2.h[0] \n" - "smlal2 v17.4s, v10.8h, v2.h[0] \n" - "smlal v18.4s, v10.4h, v2.h[1] \n" - "smlal2 v19.4s, v10.8h, v2.h[1] \n" - "smlal v20.4s, v10.4h, v2.h[2] \n" - "smlal2 v21.4s, v10.8h, v2.h[2] \n" - "smlal v22.4s, v10.4h, v2.h[3] \n" - "smlal2 v23.4s, v10.8h, v2.h[3] \n" - "smlal v24.4s, v10.4h, v2.h[4] \n" - "smlal2 v25.4s, v10.8h, v2.h[4] \n" - "smlal v26.4s, v10.4h, v2.h[5] \n" - "smlal2 v27.4s, v10.8h, v2.h[5] \n" - "smlal v28.4s, v10.4h, v2.h[6] \n" - "smlal2 v29.4s, v10.8h, v2.h[6] \n" - "smlal v30.4s, v10.4h, v2.h[7] \n" - "smlal2 v31.4s, v10.8h, v2.h[7] \n" - - "prfm pldl1keep, [%10, #512] \n" - "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%10], #64 \n" - - "smlal v16.4s, v11.4h, v3.h[0] \n" - "smlal2 v17.4s, v11.8h, v3.h[0] \n" - "smlal v18.4s, v11.4h, v3.h[1] \n" - "smlal2 v19.4s, v11.8h, v3.h[1] \n" - "smlal v20.4s, v11.4h, v3.h[2] \n" - "smlal2 v21.4s, v11.8h, v3.h[2] \n" - "smlal v22.4s, v11.4h, v3.h[3] \n" - "smlal2 v23.4s, v11.8h, v3.h[3] \n" - "smlal v24.4s, v11.4h, v3.h[4] \n" - "smlal2 v25.4s, v11.8h, v3.h[4] \n" - "smlal v26.4s, v11.4h, v3.h[5] \n" - "smlal2 v27.4s, v11.8h, v3.h[5] \n" - "smlal v28.4s, v11.4h, v3.h[6] \n" - "smlal2 v29.4s, v11.8h, v3.h[6] \n" - "smlal v30.4s, v11.4h, v3.h[7] \n" - "smlal2 v31.4s, v11.8h, v3.h[7] \n" - - "smlal v16.4s, v12.4h, v4.h[0] \n" - "smlal2 v17.4s, v12.8h, v4.h[0] \n" - "smlal v18.4s, v12.4h, v4.h[1] \n" - "smlal2 v19.4s, v12.8h, v4.h[1] \n" - "smlal v20.4s, v12.4h, v4.h[2] \n" - "smlal2 v21.4s, v12.8h, v4.h[2] \n" - "smlal v22.4s, v12.4h, v4.h[3] \n" - "smlal2 v23.4s, v12.8h, v4.h[3] \n" - "smlal v24.4s, v12.4h, v4.h[4] \n" - "smlal2 v25.4s, v12.8h, v4.h[4] \n" - "smlal v26.4s, v12.4h, v4.h[5] \n" - "smlal2 v27.4s, v12.8h, v4.h[5] \n" - "smlal v28.4s, v12.4h, v4.h[6] \n" - "smlal2 v29.4s, v12.8h, v4.h[6] \n" - "smlal v30.4s, v12.4h, v4.h[7] \n" - "smlal2 v31.4s, v12.8h, v4.h[7] \n" - - "smlal v16.4s, v13.4h, v5.h[0] \n" - "smlal2 v17.4s, v13.8h, v5.h[0] \n" - "smlal v18.4s, v13.4h, v5.h[1] \n" - "smlal2 v19.4s, v13.8h, v5.h[1] \n" - "smlal v20.4s, v13.4h, v5.h[2] \n" - "smlal2 v21.4s, v13.8h, v5.h[2] \n" - "smlal v22.4s, v13.4h, v5.h[3] \n" - "smlal2 v23.4s, v13.8h, v5.h[3] \n" - "smlal v24.4s, v13.4h, v5.h[4] \n" - "smlal2 v25.4s, v13.8h, v5.h[4] \n" - "smlal v26.4s, v13.4h, v5.h[5] \n" - "smlal2 v27.4s, v13.8h, v5.h[5] \n" - "smlal v28.4s, v13.4h, v5.h[6] \n" - "smlal2 v29.4s, v13.8h, v5.h[6] \n" - "smlal v30.4s, v13.4h, v5.h[7] \n" - "smlal2 v31.4s, v13.8h, v5.h[7] \n" - - "smlal v16.4s, v14.4h, v6.h[0] \n" - "smlal2 v17.4s, v14.8h, v6.h[0] \n" - "smlal v18.4s, v14.4h, v6.h[1] \n" - "smlal2 v19.4s, v14.8h, v6.h[1] \n" - "smlal v20.4s, v14.4h, v6.h[2] \n" - "smlal2 v21.4s, v14.8h, v6.h[2] \n" - "smlal v22.4s, v14.4h, v6.h[3] \n" - "smlal2 v23.4s, v14.8h, v6.h[3] \n" - "smlal 
v24.4s, v14.4h, v6.h[4] \n" - "smlal2 v25.4s, v14.8h, v6.h[4] \n" - "smlal v26.4s, v14.4h, v6.h[5] \n" - "smlal2 v27.4s, v14.8h, v6.h[5] \n" - "smlal v28.4s, v14.4h, v6.h[6] \n" - "smlal2 v29.4s, v14.8h, v6.h[6] \n" - "smlal v30.4s, v14.4h, v6.h[7] \n" - "smlal2 v31.4s, v14.8h, v6.h[7] \n" - - "subs %w0, %w0, #1 \n" - - "smlal v16.4s, v15.4h, v7.h[0] \n" - "smlal2 v17.4s, v15.8h, v7.h[0] \n" - "smlal v18.4s, v15.4h, v7.h[1] \n" - "smlal2 v19.4s, v15.8h, v7.h[1] \n" - "smlal v20.4s, v15.4h, v7.h[2] \n" - "smlal2 v21.4s, v15.8h, v7.h[2] \n" - "smlal v22.4s, v15.4h, v7.h[3] \n" - "smlal2 v23.4s, v15.8h, v7.h[3] \n" - "smlal v24.4s, v15.4h, v7.h[4] \n" - "smlal2 v25.4s, v15.8h, v7.h[4] \n" - "smlal v26.4s, v15.4h, v7.h[5] \n" - "smlal2 v27.4s, v15.8h, v7.h[5] \n" - "smlal v28.4s, v15.4h, v7.h[6] \n" - "smlal2 v29.4s, v15.8h, v7.h[6] \n" - "smlal v30.4s, v15.4h, v7.h[7] \n" - "smlal2 v31.4s, v15.8h, v7.h[7] \n" - - "bne 0b \n" - - "st1 {v16.4s, v17.4s}, [%1], #32 \n" - "st1 {v18.4s, v19.4s}, [%2], #32 \n" - "st1 {v20.4s, v21.4s}, [%3], #32 \n" - "st1 {v22.4s, v23.4s}, [%4], #32 \n" - "st1 {v24.4s, v25.4s}, [%5], #32 \n" - "st1 {v26.4s, v27.4s}, [%6], #32 \n" - "st1 {v28.4s, v29.4s}, [%7], #32 \n" - "st1 {v30.4s, v31.4s}, [%8], #32 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(output1_tm), // %2 - "=r"(output2_tm), // %3 - "=r"(output3_tm), // %4 - "=r"(output4_tm), // %5 - "=r"(output5_tm), // %6 - "=r"(output6_tm), // %7 - "=r"(output7_tm), // %8 - "=r"(r0), // %9 - "=r"(kptr) // %10 - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(output2_tm), - "4"(output3_tm), - "5"(output4_tm), - "6"(output5_tm), - "7"(output6_tm), - "8"(output7_tm), - "9"(r0), - "10"(kptr) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val0), vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val0), vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val0), vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val0), vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val0), vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val0), vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val0), vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val0), vget_high_s16(_w0), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val0), vget_low_s16(_w1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val0), vget_low_s16(_w1), 1); - _sum2 = vmlal_lane_s16(_sum2, 
vget_high_s16(_val0), vget_low_s16(_w1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val0), vget_low_s16(_w1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val0), vget_high_s16(_w1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val0), vget_high_s16(_w1), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val0), vget_high_s16(_w1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val0), vget_high_s16(_w1), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val1), vget_low_s16(_w2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val1), vget_low_s16(_w2), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val1), vget_low_s16(_w2), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val1), vget_low_s16(_w2), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val1), vget_high_s16(_w2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val1), vget_high_s16(_w2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val1), vget_high_s16(_w2), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val1), vget_high_s16(_w2), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val1), vget_low_s16(_w3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val1), vget_low_s16(_w3), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val1), vget_low_s16(_w3), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val1), vget_low_s16(_w3), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val1), vget_high_s16(_w3), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val1), vget_high_s16(_w3), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val1), vget_high_s16(_w3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val1), vget_high_s16(_w3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val2), vget_low_s16(_w4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val2), vget_low_s16(_w4), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val2), vget_low_s16(_w4), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val2), vget_low_s16(_w4), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val2), vget_high_s16(_w4), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val2), vget_high_s16(_w4), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val2), vget_high_s16(_w4), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val2), vget_high_s16(_w4), 3); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val2), vget_low_s16(_w5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val2), vget_low_s16(_w5), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val2), vget_low_s16(_w5), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val2), vget_low_s16(_w5), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val2), vget_high_s16(_w5), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val2), vget_high_s16(_w5), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val2), vget_high_s16(_w5), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val2), vget_high_s16(_w5), 3); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val3), vget_low_s16(_w6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val3), vget_low_s16(_w6), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val3), vget_low_s16(_w6), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val3), vget_low_s16(_w6), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val3), vget_high_s16(_w6), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val3), 
vget_high_s16(_w6), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val3), vget_high_s16(_w6), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val3), vget_high_s16(_w6), 3); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val3), vget_low_s16(_w7), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val3), vget_low_s16(_w7), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val3), vget_low_s16(_w7), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val3), vget_low_s16(_w7), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val3), vget_high_s16(_w7), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val3), vget_high_s16(_w7), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val3), vget_high_s16(_w7), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val3), vget_high_s16(_w7), 3); - - r0 += 32; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - vst1q_s32(output4_tm, _sum4); - vst1q_s32(output5_tm, _sum5); - vst1q_s32(output6_tm, _sum6); - vst1q_s32(output7_tm, _sum7); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - output4_tm += 4; - output5_tm += 4; - output6_tm += 4; - output7_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 64; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output4_tm[0] = vgetq_lane_s32(_sum1, 0); - output5_tm[0] = vgetq_lane_s32(_sum1, 1); - output6_tm[0] = 
vgetq_lane_s32(_sum1, 2); - output7_tm[0] = vgetq_lane_s32(_sum1, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - output4_tm += 1; - output5_tm += 1; - output6_tm += 1; - output7_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 8 + p % 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - int16x8_t _r1 = vld1q_s16(r0 + 8); - int16x8_t _r2 = vld1q_s16(r0 + 16); - int16x8_t _r3 = vld1q_s16(r0 + 24); - int16x8_t _r4 = vld1q_s16(r0 + 32); - int16x8_t _r5 = vld1q_s16(r0 + 40); - int16x8_t _r6 = vld1q_s16(r0 + 48); - int16x8_t _r7 = vld1q_s16(r0 + 56); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r0), vget_low_s16(_k0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r1), vget_low_s16(_k0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r1), vget_low_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r2), vget_low_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r2), vget_low_s16(_k0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r3), vget_low_s16(_k0), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r3), vget_low_s16(_k0), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r4), vget_high_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r4), vget_high_s16(_k0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r5), vget_high_s16(_k0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r5), vget_high_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r6), vget_high_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r6), vget_high_s16(_k0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r7), vget_high_s16(_k0), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r7), vget_high_s16(_k0), 3); - - kptr += 8; - r0 += 64; - } - - _sum0 = vaddq_s32(_sum0, _sum2); - _sum1 = vaddq_s32(_sum1, _sum3); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - - output0_tm += 8; - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - int16x8_t _r1 = vld1q_s16(r0 + 8); - int16x8_t _r2 = vld1q_s16(r0 + 16); - int16x8_t _r3 = vld1q_s16(r0 + 24); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r0), vget_low_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r1), vget_low_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r1), vget_low_s16(_k0), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r2), vget_high_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, 
vget_high_s16(_r2), vget_high_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r3), vget_high_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r3), vget_high_s16(_k0), 3); - - kptr += 8; - r0 += 32; - } - - int32x4_t _sum01 = vaddq_s32(_sum0, _sum1); - - vst1q_s32(output0_tm, _sum01); - - output0_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0)); - _sum1 = vmlal_s16(_sum1, vget_high_s16(_r0), vget_high_s16(_k0)); - - kptr += 8; - r0 += 8; - } - - int32x4_t _sum = vaddq_s32(_sum0, _sum1); -#if __aarch64__ - int sum = vaddvq_s32(_sum); // dot -#else - int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum)); - _ss = vpadd_s32(_ss, _ss); - int sum = vget_lane_s32(_ss, 0); -#endif - - output0_tm[0] = sum; - - output0_tm++; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h b/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h deleted file mode 100644 index a17559f6cc2..00000000000 --- a/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h +++ /dev/null @@ -1,1835 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -static void convolution_winograd_dot_pack8to4_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 16u, 8, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __aarch64__ - if (tiles >= 12) - bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 8) - bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - short* tm2p = tm2.row(i / 12); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 12x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n" - "ld4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%0] \n" - - "sub %0, %0, #128 \n" - - "uzp1 v20.8h, v0.8h, v4.8h \n" // 0 - "uzp1 v21.8h, v16.8h, v1.8h \n" // 1 - "uzp1 v22.8h, v5.8h, v17.8h \n" // 2 - "uzp1 v23.8h, v2.8h, v6.8h \n" // 3 - "uzp1 v24.8h, v18.8h, v3.8h \n" // 4 - "uzp1 v25.8h, v7.8h, v19.8h \n" // 5 - "uzp2 v26.8h, v0.8h, v4.8h \n" // 6 - "uzp2 v27.8h, v16.8h, v1.8h \n" // 7 - "uzp2 v28.8h, v5.8h, v17.8h \n" // 8 - "uzp2 v29.8h, v2.8h, v6.8h \n" // 9 - "uzp2 v30.8h, v18.8h, v3.8h \n" // 10 - "uzp2 v31.8h, v7.8h, v19.8h \n" // 11 - - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [%1], #64 \n" - "st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i + 7 < tiles; i += 8) - { - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" - "sub %0, %0, #64 \n" - - "uzp1 v16.8h, v0.8h, v4.8h \n" - "uzp2 v20.8h, v0.8h, v4.8h \n" - "uzp1 v17.8h, v1.8h, v5.8h \n" - "uzp2 v21.8h, v1.8h, v5.8h \n" - "uzp1 
v18.8h, v2.8h, v6.8h \n" - "uzp2 v22.8h, v2.8h, v6.8h \n" - "uzp1 v19.8h, v3.8h, v7.8h \n" - "uzp2 v23.8h, v3.8h, v7.8h \n" - - "st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); - - r0 += bottom_blob_tm.cstep * 8; - } - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - short* tmpptr = tm2.row(i / 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" - "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3"); -#else - asm volatile( - "pld [%0, #512] \n" - "vldm %0, {d0-d7} \n" - "vstm %1!, {d0-d7} \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0", "q1", "q2", "q3"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #256] \n" - "ld1 {v0.8h, v1.8h}, [%0] \n" - "st1 {v0.8h, v1.8h}, [%1], #32 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1"); -#else - asm volatile( - "pld [%0, #256] \n" - "vld1.s16 {d0-d3}, [%0 :128] \n" - "vst1.s16 {d0-d3}, [%1 :128]! \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0", "q1"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i < tiles; i++) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.8h}, [%0] \n" - "st1 {v0.8h}, [%1], #16 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0"); -#else - asm volatile( - "pld [%0, #128] \n" - "vld1.s16 {d0-d1}, [%0 :128] \n" - "vst1.s16 {d0-d1}, [%1 :128]! 
\n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 16u, 4, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 2; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - - const Mat kernel0_tm = kernel_tm.channel(p / 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - const short* r0 = bb2.row(i / 12); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - asm volatile( - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 - - "eor v8.16b, v8.16b, v8.16b \n" - "eor v9.16b, v9.16b, v9.16b \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 - - "eor v10.16b, v10.16b, v10.16b \n" - "eor v11.16b, v11.16b, v11.16b \n" - - "prfm pldl1keep, [%3, #256] \n" - - "eor v12.16b, v12.16b, v12.16b \n" - "eor v13.16b, v13.16b, v13.16b \n" - - "prfm pldl1keep, [%4, #256] \n" - - "eor v14.16b, v14.16b, v14.16b \n" - "eor v15.16b, v15.16b, v15.16b \n" - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - "eor v20.16b, v20.16b, v20.16b \n" - "eor v21.16b, v21.16b, v21.16b \n" - "eor v22.16b, v22.16b, v22.16b \n" - "eor v23.16b, v23.16b, v23.16b \n" - "eor v24.16b, v24.16b, v24.16b \n" - "eor v25.16b, v25.16b, v25.16b \n" - "eor v26.16b, v26.16b, v26.16b \n" - "eor v27.16b, v27.16b, v27.16b \n" - "eor v28.16b, v28.16b, v28.16b \n" - "eor v29.16b, v29.16b, v29.16b \n" - "eor v30.16b, v30.16b, v30.16b \n" - "eor v31.16b, v31.16b, v31.16b \n" - - "0: \n" - - "smlal v8.4s, v4.4h, v0.h[0] \n" - "smlal2 v20.4s, v4.8h, v0.h[0] \n" - "smlal v9.4s, v4.4h, v0.h[1] \n" - "smlal2 v21.4s, v4.8h, v0.h[1] \n" - "smlal v10.4s, v4.4h, v0.h[2] \n" - "smlal2 v22.4s, v4.8h, v0.h[2] \n" - "smlal v11.4s, v4.4h, v0.h[3] \n" - "smlal2 v23.4s, v4.8h, v0.h[3] \n" - "smlal v12.4s, v4.4h, v0.h[4] \n" - "smlal2 v24.4s, v4.8h, v0.h[4] \n" - "smlal v13.4s, v4.4h, v0.h[5] \n" - "smlal2 v25.4s, v4.8h, v0.h[5] \n" - "smlal v14.4s, v4.4h, v0.h[6] \n" - "smlal2 v26.4s, v4.8h, v0.h[6] \n" - "smlal v15.4s, v4.4h, v0.h[7] \n" - "smlal2 v27.4s, v4.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r23 - - "smlal v16.4s, v4.4h, v1.h[0] \n" - "smlal2 v28.4s, v4.8h, v1.h[0] \n" - "smlal v17.4s, v4.4h, v1.h[1] \n" - "smlal2 v29.4s, v4.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v18.4s, v4.4h, v1.h[2] \n" - "smlal2 v30.4s, v4.8h, v1.h[2] \n" - "smlal v19.4s, v4.4h, v1.h[3] \n" - "smlal2 v31.4s, v4.8h, v1.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w23 - - "smlal v8.4s, v5.4h, v1.h[4] \n" - "smlal2 v20.4s, v5.8h, v1.h[4] \n" - "smlal v9.4s, v5.4h, v1.h[5] \n" - "smlal2 v21.4s, v5.8h, v1.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v5.4h, v1.h[6] \n" - "smlal2 v22.4s, v5.8h, v1.h[6] \n" - "smlal v11.4s, v5.4h, v1.h[7] \n" - "smlal2 v23.4s, v5.8h, v1.h[7] \n" - "smlal v12.4s, v5.4h, v2.h[0] \n" - "smlal2 v24.4s, v5.8h, v2.h[0] \n" - "smlal v13.4s, v5.4h, v2.h[1] \n" - "smlal2 v25.4s, v5.8h, v2.h[1] \n" - "smlal v14.4s, v5.4h, v2.h[2] \n" - "smlal2 v26.4s, v5.8h, v2.h[2] \n" - "smlal v15.4s, v5.4h, 
v2.h[3] \n" - "smlal2 v27.4s, v5.8h, v2.h[3] \n" - "smlal v16.4s, v5.4h, v2.h[4] \n" - "smlal2 v28.4s, v5.8h, v2.h[4] \n" - "smlal v17.4s, v5.4h, v2.h[5] \n" - "smlal2 v29.4s, v5.8h, v2.h[5] \n" - "smlal v18.4s, v5.4h, v2.h[6] \n" - "smlal2 v30.4s, v5.8h, v2.h[6] \n" - "smlal v19.4s, v5.4h, v2.h[7] \n" - "smlal2 v31.4s, v5.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r45 - - "smlal v8.4s, v6.4h, v3.h[0] \n" - "smlal2 v20.4s, v6.8h, v3.h[0] \n" - "smlal v9.4s, v6.4h, v3.h[1] \n" - "smlal2 v21.4s, v6.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v10.4s, v6.4h, v3.h[2] \n" - "smlal2 v22.4s, v6.8h, v3.h[2] \n" - "smlal v11.4s, v6.4h, v3.h[3] \n" - "smlal2 v23.4s, v6.8h, v3.h[3] \n" - "smlal v12.4s, v6.4h, v3.h[4] \n" - "smlal2 v24.4s, v6.8h, v3.h[4] \n" - "smlal v13.4s, v6.4h, v3.h[5] \n" - "smlal2 v25.4s, v6.8h, v3.h[5] \n" - "smlal v14.4s, v6.4h, v3.h[6] \n" - "smlal2 v26.4s, v6.8h, v3.h[6] \n" - "smlal v15.4s, v6.4h, v3.h[7] \n" - "smlal2 v27.4s, v6.8h, v3.h[7] \n" - - "smlal v16.4s, v6.4h, v0.h[0] \n" - "smlal2 v28.4s, v6.8h, v0.h[0] \n" - "smlal v17.4s, v6.4h, v0.h[1] \n" - "smlal2 v29.4s, v6.8h, v0.h[1] \n" - "smlal v18.4s, v6.4h, v0.h[2] \n" - "smlal2 v30.4s, v6.8h, v0.h[2] \n" - "smlal v19.4s, v6.4h, v0.h[3] \n" - "smlal2 v31.4s, v6.8h, v0.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w45 - - "smlal v8.4s, v7.4h, v0.h[4] \n" - "smlal2 v20.4s, v7.8h, v0.h[4] \n" - "smlal v9.4s, v7.4h, v0.h[5] \n" - "smlal2 v21.4s, v7.8h, v0.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v7.4h, v0.h[6] \n" - "smlal2 v22.4s, v7.8h, v0.h[6] \n" - "smlal v11.4s, v7.4h, v0.h[7] \n" - "smlal2 v23.4s, v7.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r67 - - "smlal v12.4s, v7.4h, v1.h[0] \n" - "smlal2 v24.4s, v7.8h, v1.h[0] \n" - "smlal v13.4s, v7.4h, v1.h[1] \n" - "smlal2 v25.4s, v7.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v14.4s, v7.4h, v1.h[2] \n" - "smlal2 v26.4s, v7.8h, v1.h[2] \n" - "smlal v15.4s, v7.4h, v1.h[3] \n" - "smlal2 v27.4s, v7.8h, v1.h[3] \n" - "smlal v16.4s, v7.4h, v1.h[4] \n" - "smlal2 v28.4s, v7.8h, v1.h[4] \n" - "smlal v17.4s, v7.4h, v1.h[5] \n" - "smlal2 v29.4s, v7.8h, v1.h[5] \n" - "smlal v18.4s, v7.4h, v1.h[6] \n" - "smlal2 v30.4s, v7.8h, v1.h[6] \n" - "smlal v19.4s, v7.4h, v1.h[7] \n" - "smlal2 v31.4s, v7.8h, v1.h[7] \n" - - "smlal v8.4s, v4.4h, v2.h[0] \n" - "smlal2 v20.4s, v4.8h, v2.h[0] \n" - "smlal v9.4s, v4.4h, v2.h[1] \n" - "smlal2 v21.4s, v4.8h, v2.h[1] \n" - "smlal v10.4s, v4.4h, v2.h[2] \n" - "smlal2 v22.4s, v4.8h, v2.h[2] \n" - "smlal v11.4s, v4.4h, v2.h[3] \n" - "smlal2 v23.4s, v4.8h, v2.h[3] \n" - "smlal v12.4s, v4.4h, v2.h[4] \n" - "smlal2 v24.4s, v4.8h, v2.h[4] \n" - "smlal v13.4s, v4.4h, v2.h[5] \n" - "smlal2 v25.4s, v4.8h, v2.h[5] \n" - "smlal v14.4s, v4.4h, v2.h[6] \n" - "smlal2 v26.4s, v4.8h, v2.h[6] \n" - "smlal v15.4s, v4.4h, v2.h[7] \n" - "smlal2 v27.4s, v4.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r89 - - "smlal v16.4s, v4.4h, v3.h[0] \n" - "smlal2 v28.4s, v4.8h, v3.h[0] \n" - "smlal v17.4s, v4.4h, v3.h[1] \n" - "smlal2 v29.4s, v4.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v18.4s, v4.4h, v3.h[2] \n" - "smlal2 v30.4s, v4.8h, v3.h[2] \n" - "smlal v19.4s, v4.4h, v3.h[3] \n" - "smlal2 v31.4s, v4.8h, v3.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w67 - - "smlal v8.4s, v5.4h, v3.h[4] \n" - "smlal2 v20.4s, v5.8h, v3.h[4] \n" - "smlal v9.4s, v5.4h, v3.h[5] \n" - "smlal2 v21.4s, v5.8h, v3.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal 
v10.4s, v5.4h, v3.h[6] \n" - "smlal2 v22.4s, v5.8h, v3.h[6] \n" - "smlal v11.4s, v5.4h, v3.h[7] \n" - "smlal2 v23.4s, v5.8h, v3.h[7] \n" - - "smlal v12.4s, v5.4h, v0.h[0] \n" - "smlal2 v24.4s, v5.8h, v0.h[0] \n" - "smlal v13.4s, v5.4h, v0.h[1] \n" - "smlal2 v25.4s, v5.8h, v0.h[1] \n" - "smlal v14.4s, v5.4h, v0.h[2] \n" - "smlal2 v26.4s, v5.8h, v0.h[2] \n" - "smlal v15.4s, v5.4h, v0.h[3] \n" - "smlal2 v27.4s, v5.8h, v0.h[3] \n" - "smlal v16.4s, v5.4h, v0.h[4] \n" - "smlal2 v28.4s, v5.8h, v0.h[4] \n" - "smlal v17.4s, v5.4h, v0.h[5] \n" - "smlal2 v29.4s, v5.8h, v0.h[5] \n" - "smlal v18.4s, v5.4h, v0.h[6] \n" - "smlal2 v30.4s, v5.8h, v0.h[6] \n" - "smlal v19.4s, v5.4h, v0.h[7] \n" - "smlal2 v31.4s, v5.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r1011 - - "smlal v8.4s, v6.4h, v1.h[0] \n" - "smlal2 v20.4s, v6.8h, v1.h[0] \n" - "smlal v9.4s, v6.4h, v1.h[1] \n" - "smlal2 v21.4s, v6.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v10.4s, v6.4h, v1.h[2] \n" - "smlal2 v22.4s, v6.8h, v1.h[2] \n" - "smlal v11.4s, v6.4h, v1.h[3] \n" - "smlal2 v23.4s, v6.8h, v1.h[3] \n" - "smlal v12.4s, v6.4h, v1.h[4] \n" - "smlal2 v24.4s, v6.8h, v1.h[4] \n" - "smlal v13.4s, v6.4h, v1.h[5] \n" - "smlal2 v25.4s, v6.8h, v1.h[5] \n" - "smlal v14.4s, v6.4h, v1.h[6] \n" - "smlal2 v26.4s, v6.8h, v1.h[6] \n" - "smlal v15.4s, v6.4h, v1.h[7] \n" - "smlal2 v27.4s, v6.8h, v1.h[7] \n" - "smlal v16.4s, v6.4h, v2.h[0] \n" - "smlal2 v28.4s, v6.8h, v2.h[0] \n" - "smlal v17.4s, v6.4h, v2.h[1] \n" - "smlal2 v29.4s, v6.8h, v2.h[1] \n" - "smlal v18.4s, v6.4h, v2.h[2] \n" - "smlal2 v30.4s, v6.8h, v2.h[2] \n" - "smlal v19.4s, v6.4h, v2.h[3] \n" - "smlal2 v31.4s, v6.8h, v2.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 - - "smlal v8.4s, v7.4h, v2.h[4] \n" - "smlal2 v20.4s, v7.8h, v2.h[4] \n" - "smlal v9.4s, v7.4h, v2.h[5] \n" - "smlal2 v21.4s, v7.8h, v2.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v7.4h, v2.h[6] \n" - "smlal2 v22.4s, v7.8h, v2.h[6] \n" - "smlal v11.4s, v7.4h, v2.h[7] \n" - "smlal2 v23.4s, v7.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 - - "smlal v12.4s, v7.4h, v3.h[0] \n" - "smlal2 v24.4s, v7.8h, v3.h[0] \n" - "smlal v13.4s, v7.4h, v3.h[1] \n" - "smlal2 v25.4s, v7.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v14.4s, v7.4h, v3.h[2] \n" - "smlal2 v26.4s, v7.8h, v3.h[2] \n" - "smlal v15.4s, v7.4h, v3.h[3] \n" - "smlal2 v27.4s, v7.8h, v3.h[3] \n" - "smlal v16.4s, v7.4h, v3.h[4] \n" - "smlal2 v28.4s, v7.8h, v3.h[4] \n" - "smlal v17.4s, v7.4h, v3.h[5] \n" - "smlal2 v29.4s, v7.8h, v3.h[5] \n" - - "subs %w0, %w0, #1 \n" - - "smlal v18.4s, v7.4h, v3.h[6] \n" - "smlal2 v30.4s, v7.8h, v3.h[6] \n" - "smlal v19.4s, v7.4h, v3.h[7] \n" - "smlal2 v31.4s, v7.8h, v3.h[7] \n" - - "bne 0b \n" - - "sub %3, %3, #32 \n" - "sub %4, %4, #32 \n" - - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%2], #64 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(output1_tm), // %2 - "=r"(r0), // %3 - "=r"(k0) // %4 - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(r0), - "4"(k0) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); - } - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 12 + (i % 12) / 8); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - int32x4_t _sum8 = vdupq_n_s32(0); - int32x4_t _sum9 = vdupq_n_s32(0); - int32x4_t _suma = vdupq_n_s32(0); - int32x4_t _sumb = vdupq_n_s32(0); - int32x4_t _sumc = vdupq_n_s32(0); - int32x4_t _sumd = vdupq_n_s32(0); - int32x4_t _sume = vdupq_n_s32(0); - int32x4_t _sumf = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - int16x8_t _val4 = vld1q_s16(r0 + 32); - int16x8_t _val5 = vld1q_s16(r0 + 40); - int16x8_t _val6 = vld1q_s16(r0 + 48); - int16x8_t _val7 = vld1q_s16(r0 + 56); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val0), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val0), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val0), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val0), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w0), vget_high_s16(_val0), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w0), vget_high_s16(_val0), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w0), vget_high_s16(_val0), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w0), vget_high_s16(_val0), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w0), vget_high_s16(_val0), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w0), vget_high_s16(_val0), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w0), vget_high_s16(_val0), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w0), vget_high_s16(_val0), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val1), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val1), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val1), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val1), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w1), vget_high_s16(_val1), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w1), vget_high_s16(_val1), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w1), vget_high_s16(_val1), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w1), vget_high_s16(_val1), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w1), vget_high_s16(_val1), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w1), vget_high_s16(_val1), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w1), 
vget_high_s16(_val1), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w1), vget_high_s16(_val1), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val2), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val2), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val2), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_low_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_low_s16(_val2), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_low_s16(_val2), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w2), vget_high_s16(_val2), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w2), vget_high_s16(_val2), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w2), vget_high_s16(_val2), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w2), vget_high_s16(_val2), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w2), vget_high_s16(_val2), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w2), vget_high_s16(_val2), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w2), vget_high_s16(_val2), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w2), vget_high_s16(_val2), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val3), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val3), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val3), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_low_s16(_val3), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_low_s16(_val3), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_low_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_low_s16(_val3), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w3), vget_high_s16(_val3), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w3), vget_high_s16(_val3), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w3), vget_high_s16(_val3), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w3), vget_high_s16(_val3), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w3), vget_high_s16(_val3), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w3), vget_high_s16(_val3), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w3), vget_high_s16(_val3), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w3), vget_high_s16(_val3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_low_s16(_val4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_low_s16(_val4), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_low_s16(_val4), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_low_s16(_val4), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w4), vget_low_s16(_val4), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w4), vget_low_s16(_val4), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w4), vget_low_s16(_val4), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w4), vget_low_s16(_val4), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w4), vget_high_s16(_val4), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w4), vget_high_s16(_val4), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w4), vget_high_s16(_val4), 1); - _sumb = 
vmlal_lane_s16(_sumb, vget_high_s16(_w4), vget_high_s16(_val4), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w4), vget_high_s16(_val4), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w4), vget_high_s16(_val4), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w4), vget_high_s16(_val4), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w4), vget_high_s16(_val4), 3); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_low_s16(_val5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_low_s16(_val5), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_low_s16(_val5), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_low_s16(_val5), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w5), vget_low_s16(_val5), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w5), vget_low_s16(_val5), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w5), vget_low_s16(_val5), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w5), vget_low_s16(_val5), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w5), vget_high_s16(_val5), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w5), vget_high_s16(_val5), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w5), vget_high_s16(_val5), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w5), vget_high_s16(_val5), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w5), vget_high_s16(_val5), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w5), vget_high_s16(_val5), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w5), vget_high_s16(_val5), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w5), vget_high_s16(_val5), 3); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_low_s16(_val6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_low_s16(_val6), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_low_s16(_val6), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_low_s16(_val6), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w6), vget_low_s16(_val6), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w6), vget_low_s16(_val6), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w6), vget_low_s16(_val6), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w6), vget_low_s16(_val6), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w6), vget_high_s16(_val6), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w6), vget_high_s16(_val6), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w6), vget_high_s16(_val6), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w6), vget_high_s16(_val6), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w6), vget_high_s16(_val6), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w6), vget_high_s16(_val6), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w6), vget_high_s16(_val6), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w6), vget_high_s16(_val6), 3); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_low_s16(_val7), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_low_s16(_val7), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_low_s16(_val7), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_low_s16(_val7), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w7), vget_low_s16(_val7), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w7), vget_low_s16(_val7), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w7), vget_low_s16(_val7), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w7), 
vget_low_s16(_val7), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w7), vget_high_s16(_val7), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w7), vget_high_s16(_val7), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w7), vget_high_s16(_val7), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w7), vget_high_s16(_val7), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w7), vget_high_s16(_val7), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w7), vget_high_s16(_val7), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w7), vget_high_s16(_val7), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w7), vget_high_s16(_val7), 3); - - r0 += 64; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output1_tm + 8, _sum5); - vst1q_s32(output0_tm + 12, _sum6); - vst1q_s32(output1_tm + 12, _sum7); - vst1q_s32(output0_tm + 16, _sum8); - vst1q_s32(output1_tm + 16, _sum9); - vst1q_s32(output0_tm + 20, _suma); - vst1q_s32(output1_tm + 20, _sumb); - vst1q_s32(output0_tm + 24, _sumc); - vst1q_s32(output1_tm + 24, _sumd); - vst1q_s32(output0_tm + 28, _sume); - vst1q_s32(output1_tm + 28, _sumf); - output0_tm += 32; - output1_tm += 32; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __aarch64__ - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 0); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val2), 0); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val3), 0); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val2), 1); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val3), 1); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val3), 1); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), 
vget_low_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val1), 2); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_low_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_low_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_low_s16(_val3), 2); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_low_s16(_val2), 3); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_low_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_low_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_low_s16(_val3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_high_s16(_val1), 0); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w4), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w4), vget_high_s16(_val2), 0); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w4), vget_high_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w4), vget_high_s16(_val3), 0); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_high_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_high_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w5), vget_high_s16(_val2), 1); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w5), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w5), vget_high_s16(_val3), 1); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w5), vget_high_s16(_val3), 1); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_high_s16(_val1), 2); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w6), vget_high_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w6), vget_high_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w6), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w6), vget_high_s16(_val3), 2); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_high_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_high_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w7), vget_high_s16(_val2), 3); - 
_sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w7), vget_high_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w7), vget_high_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w7), vget_high_s16(_val3), 3); - - r0 += 32; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output1_tm + 8, _sum5); - vst1q_s32(output0_tm + 12, _sum6); - vst1q_s32(output1_tm + 12, _sum7); - output0_tm += 16; - output1_tm += 16; -#else - asm volatile( - "veor q8, q8 \n" - "veor q9, q9 \n" - "veor q10, q10 \n" - "veor q11, q11 \n" - "veor q12, q12 \n" - "veor q13, q13 \n" - "veor q14, q14 \n" - "veor q15, q15 \n" - - "0: \n" - - "pld [%3, #256] \n" - "pld [%3, #512] \n" - "vldm %3!, {d0-d7} \n" - - "pld [%4, #256] \n" - "vld1.s16 {d8-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d0[0] \n" - "vmlal.s16 q12, d9, d0[0] \n" - "vmlal.s16 q9, d8, d2[0] \n" - "vmlal.s16 q13, d9, d2[0] \n" - "vmlal.s16 q10, d8, d4[0] \n" - "vmlal.s16 q14, d9, d4[0] \n" - "vmlal.s16 q11, d8, d6[0] \n" - "vmlal.s16 q15, d9, d6[0] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d0[1] \n" - "vmlal.s16 q12, d11, d0[1] \n" - "vmlal.s16 q9, d10, d2[1] \n" - "vmlal.s16 q13, d11, d2[1] \n" - "vmlal.s16 q10, d10, d4[1] \n" - "vmlal.s16 q14, d11, d4[1] \n" - "vmlal.s16 q11, d10, d6[1] \n" - "vmlal.s16 q15, d11, d6[1] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d0[2] \n" - "vmlal.s16 q12, d9, d0[2] \n" - "vmlal.s16 q9, d8, d2[2] \n" - "vmlal.s16 q13, d9, d2[2] \n" - "vmlal.s16 q10, d8, d4[2] \n" - "vmlal.s16 q14, d9, d4[2] \n" - "vmlal.s16 q11, d8, d6[2] \n" - "vmlal.s16 q15, d9, d6[2] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d0[3] \n" - "vmlal.s16 q12, d11, d0[3] \n" - "vmlal.s16 q9, d10, d2[3] \n" - "vmlal.s16 q13, d11, d2[3] \n" - "vmlal.s16 q10, d10, d4[3] \n" - "vmlal.s16 q14, d11, d4[3] \n" - "vmlal.s16 q11, d10, d6[3] \n" - "vmlal.s16 q15, d11, d6[3] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d1[0] \n" - "vmlal.s16 q12, d9, d1[0] \n" - "vmlal.s16 q9, d8, d3[0] \n" - "vmlal.s16 q13, d9, d3[0] \n" - "vmlal.s16 q10, d8, d5[0] \n" - "vmlal.s16 q14, d9, d5[0] \n" - "vmlal.s16 q11, d8, d7[0] \n" - "vmlal.s16 q15, d9, d7[0] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d1[1] \n" - "vmlal.s16 q12, d11, d1[1] \n" - "vmlal.s16 q9, d10, d3[1] \n" - "vmlal.s16 q13, d11, d3[1] \n" - "vmlal.s16 q10, d10, d5[1] \n" - "vmlal.s16 q14, d11, d5[1] \n" - "vmlal.s16 q11, d10, d7[1] \n" - "vmlal.s16 q15, d11, d7[1] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! 
\n" - - "vmlal.s16 q8, d8, d1[2] \n" - "vmlal.s16 q12, d9, d1[2] \n" - "vmlal.s16 q9, d8, d3[2] \n" - "vmlal.s16 q13, d9, d3[2] \n" - "vmlal.s16 q10, d8, d5[2] \n" - "vmlal.s16 q14, d9, d5[2] \n" - "vmlal.s16 q11, d8, d7[2] \n" - "vmlal.s16 q15, d9, d7[2] \n" - - "subs %0, %0, #1 \n" - - "vmlal.s16 q8, d10, d1[3] \n" - "vmlal.s16 q12, d11, d1[3] \n" - "vmlal.s16 q9, d10, d3[3] \n" - "vmlal.s16 q13, d11, d3[3] \n" - "vmlal.s16 q10, d10, d5[3] \n" - "vmlal.s16 q14, d11, d5[3] \n" - "vmlal.s16 q11, d10, d7[3] \n" - "vmlal.s16 q15, d11, d7[3] \n" - - "bne 0b \n" - - "vstm %1!, {d16-d23} \n" - "vstm %2!, {d24-d31} \n" - - : "=r"(nn), - "=r"(output0_tm), - "=r"(output1_tm), - "=r"(r0), - "=r"(k0) - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(r0), - "4"(k0) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val1), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val1), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_high_s16(_val1), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_high_s16(_val1), 1); - _sum3 = 
vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_high_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_high_s16(_val1), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_high_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_high_s16(_val1), 3); - - r0 += 16; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - output0_tm += 8; - output1_tm += 8; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - output0_tm += 4; - output1_tm += 4; - } - } - } - - remain_outch_start += nn_outch << 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - const short* r0 = bb2.row(i / 12); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // 
inch always > 0 - - asm volatile( - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 - - "eor v8.16b, v8.16b, v8.16b \n" - "eor v9.16b, v9.16b, v9.16b \n" - - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 - - "eor v10.16b, v10.16b, v10.16b \n" - "eor v11.16b, v11.16b, v11.16b \n" - - "prfm pldl1keep, [%2, #256] \n" - - "eor v12.16b, v12.16b, v12.16b \n" - "eor v13.16b, v13.16b, v13.16b \n" - - "prfm pldl1keep, [%3, #256] \n" - - "eor v14.16b, v14.16b, v14.16b \n" - "eor v15.16b, v15.16b, v15.16b \n" - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - - "0: \n" - - "smlal v8.4s, v4.4h, v0.h[0] \n" - "smlal v9.4s, v4.4h, v0.h[1] \n" - "smlal v10.4s, v4.4h, v0.h[2] \n" - "smlal v11.4s, v4.4h, v0.h[3] \n" - "smlal v12.4s, v4.4h, v0.h[4] \n" - "smlal v13.4s, v4.4h, v0.h[5] \n" - "smlal v14.4s, v4.4h, v0.h[6] \n" - "smlal v15.4s, v4.4h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r23 - - "smlal v16.4s, v4.4h, v1.h[0] \n" - "smlal v17.4s, v4.4h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v18.4s, v4.4h, v1.h[2] \n" - "smlal v19.4s, v4.4h, v1.h[3] \n" - - "smlal2 v8.4s, v4.8h, v1.h[4] \n" - "smlal2 v9.4s, v4.8h, v1.h[5] \n" - "smlal2 v10.4s, v4.8h, v1.h[6] \n" - "smlal2 v11.4s, v4.8h, v1.h[7] \n" - "smlal2 v12.4s, v4.8h, v2.h[0] \n" - "smlal2 v13.4s, v4.8h, v2.h[1] \n" - "smlal2 v14.4s, v4.8h, v2.h[2] \n" - "smlal2 v15.4s, v4.8h, v2.h[3] \n" - "smlal2 v16.4s, v4.8h, v2.h[4] \n" - "smlal2 v17.4s, v4.8h, v2.h[5] \n" - "smlal2 v18.4s, v4.8h, v2.h[6] \n" - "smlal2 v19.4s, v4.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r45 - - "smlal v8.4s, v5.4h, v3.h[0] \n" - "smlal v9.4s, v5.4h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v10.4s, v5.4h, v3.h[2] \n" - "smlal v11.4s, v5.4h, v3.h[3] \n" - "smlal v12.4s, v5.4h, v3.h[4] \n" - "smlal v13.4s, v5.4h, v3.h[5] \n" - "smlal v14.4s, v5.4h, v3.h[6] \n" - "smlal v15.4s, v5.4h, v3.h[7] \n" - "smlal v16.4s, v5.4h, v0.h[0] \n" - "smlal v17.4s, v5.4h, v0.h[1] \n" - "smlal v18.4s, v5.4h, v0.h[2] \n" - "smlal v19.4s, v5.4h, v0.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%3], #32 \n" // w23 - - "smlal2 v8.4s, v5.8h, v0.h[4] \n" - "smlal2 v9.4s, v5.8h, v0.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal2 v10.4s, v5.8h, v0.h[6] \n" - "smlal2 v11.4s, v5.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r67 - - "smlal2 v12.4s, v5.8h, v1.h[0] \n" - "smlal2 v13.4s, v5.8h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal2 v14.4s, v5.8h, v1.h[2] \n" - "smlal2 v15.4s, v5.8h, v1.h[3] \n" - "smlal2 v16.4s, v5.8h, v1.h[4] \n" - "smlal2 v17.4s, v5.8h, v1.h[5] \n" - "smlal2 v18.4s, v5.8h, v1.h[6] \n" - "smlal2 v19.4s, v5.8h, v1.h[7] \n" - - "smlal v8.4s, v6.4h, v2.h[0] \n" - "smlal v9.4s, v6.4h, v2.h[1] \n" - "smlal v10.4s, v6.4h, v2.h[2] \n" - "smlal v11.4s, v6.4h, v2.h[3] \n" - "smlal v12.4s, v6.4h, v2.h[4] \n" - "smlal v13.4s, v6.4h, v2.h[5] \n" - "smlal v14.4s, v6.4h, v2.h[6] \n" - "smlal v15.4s, v6.4h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r89 - - "smlal v16.4s, v6.4h, v3.h[0] \n" - "smlal v17.4s, v6.4h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v18.4s, v6.4h, v3.h[2] \n" - "smlal v19.4s, v6.4h, v3.h[3] \n" - - "smlal2 v8.4s, v6.8h, v3.h[4] \n" - "smlal2 v9.4s, v6.8h, v3.h[5] \n" - "smlal2 v10.4s, v6.8h, v3.h[6] \n" - "smlal2 v11.4s, v6.8h, v3.h[7] \n" - "smlal2 v12.4s, v6.8h, v0.h[0] \n" - "smlal2 v13.4s, v6.8h, v0.h[1] \n" - "smlal2 v14.4s, v6.8h, v0.h[2] \n" - "smlal2 v15.4s, v6.8h, 
v0.h[3] \n" - "smlal2 v16.4s, v6.8h, v0.h[4] \n" - "smlal2 v17.4s, v6.8h, v0.h[5] \n" - "smlal2 v18.4s, v6.8h, v0.h[6] \n" - "smlal2 v19.4s, v6.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r1011 - - "smlal v8.4s, v7.4h, v1.h[0] \n" - "smlal v9.4s, v7.4h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v10.4s, v7.4h, v1.h[2] \n" - "smlal v11.4s, v7.4h, v1.h[3] \n" - "smlal v12.4s, v7.4h, v1.h[4] \n" - "smlal v13.4s, v7.4h, v1.h[5] \n" - "smlal v14.4s, v7.4h, v1.h[6] \n" - "smlal v15.4s, v7.4h, v1.h[7] \n" - "smlal v16.4s, v7.4h, v2.h[0] \n" - "smlal v17.4s, v7.4h, v2.h[1] \n" - "smlal v18.4s, v7.4h, v2.h[2] \n" - "smlal v19.4s, v7.4h, v2.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 - - "smlal2 v8.4s, v7.8h, v2.h[4] \n" - "smlal2 v9.4s, v7.8h, v2.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal2 v10.4s, v7.8h, v2.h[6] \n" - "smlal2 v11.4s, v7.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 - - "smlal2 v12.4s, v7.8h, v3.h[0] \n" - "smlal2 v13.4s, v7.8h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal2 v14.4s, v7.8h, v3.h[2] \n" - "smlal2 v15.4s, v7.8h, v3.h[3] \n" - "smlal2 v16.4s, v7.8h, v3.h[4] \n" - "smlal2 v17.4s, v7.8h, v3.h[5] \n" - - "subs %w0, %w0, #1 \n" - - "smlal2 v18.4s, v7.8h, v3.h[6] \n" - "smlal2 v19.4s, v7.8h, v3.h[7] \n" - - "bne 0b \n" - - "sub %2, %2, #32 \n" - "sub %3, %3, #32 \n" - - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(r0), // %2 - "=r"(k0) // %3 - : "0"(nn), - "1"(output0_tm), - "2"(r0), - "3"(k0) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"); - } - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 12 + (i % 12) / 8); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - int16x8_t _val4 = vld1q_s16(r0 + 32); - int16x8_t _val5 = vld1q_s16(r0 + 40); - int16x8_t _val6 = vld1q_s16(r0 + 48); - int16x8_t _val7 = vld1q_s16(r0 + 56); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w0), vget_low_s16(_val0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_high_s16(_val0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w0), vget_high_s16(_val0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_high_s16(_val0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w0), vget_high_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w0), vget_low_s16(_val1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val1), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w0), vget_low_s16(_val1), 2); - _sum3 = 
vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w0), vget_high_s16(_val1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_high_s16(_val1), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w0), vget_high_s16(_val1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_high_s16(_val1), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w1), vget_low_s16(_val2), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val2), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w1), vget_low_s16(_val2), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w1), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_high_s16(_val2), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w1), vget_high_s16(_val2), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w1), vget_low_s16(_val3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val3), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w1), vget_low_s16(_val3), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val3), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w1), vget_high_s16(_val3), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_high_s16(_val3), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w1), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_high_s16(_val3), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w2), vget_low_s16(_val4), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val4), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w2), vget_low_s16(_val4), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val4), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w2), vget_high_s16(_val4), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val4), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w2), vget_high_s16(_val4), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w2), vget_low_s16(_val5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val5), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w2), vget_low_s16(_val5), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val5), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w2), vget_high_s16(_val5), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_high_s16(_val5), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w2), vget_high_s16(_val5), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val5), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w3), vget_low_s16(_val6), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val6), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w3), vget_low_s16(_val6), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val6), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w3), vget_high_s16(_val6), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val6), 2); - _sum7 = vmlal_lane_s16(_sum7, 
vget_low_s16(_w3), vget_high_s16(_val6), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w3), vget_low_s16(_val7), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val7), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w3), vget_low_s16(_val7), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val7), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w3), vget_high_s16(_val7), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val7), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w3), vget_high_s16(_val7), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val7), 3); - - r0 += 64; - k0 += 32; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - vst1q_s32(output0_tm + 8, _sum2); - vst1q_s32(output0_tm + 12, _sum3); - vst1q_s32(output0_tm + 16, _sum4); - vst1q_s32(output0_tm + 20, _sum5); - vst1q_s32(output0_tm + 24, _sum6); - vst1q_s32(output0_tm + 28, _sum7); - output0_tm += 32; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __aarch64__ - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val3), 1); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val3), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, 
vget_high_s16(_w2), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val3), 1); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val3), 3); - - r0 += 32; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum4 = vaddq_s32(_sum4, _sum5); - _sum6 = vaddq_s32(_sum6, _sum7); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output0_tm + 12, _sum6); - output0_tm += 16; -#else - asm volatile( - "veor q8, q8 \n" - "veor q9, q9 \n" - "veor q10, q10 \n" - "veor q11, q11 \n" - "veor q12, q12 \n" - "veor q13, q13 \n" - "veor q14, q14 \n" - "veor q15, q15 \n" - - "0: \n" - - "pld [%2, #256] \n" - "pld [%2, #512] \n" - "vldm %2!, {d0-d7} \n" - - "pld [%3, #256] \n" - "vld1.s16 {d8-d11}, [%3 :128]! \n" - - "vmlal.s16 q8, d8, d0[0] \n" - "vmlal.s16 q12, d9, d0[1] \n" - "vmlal.s16 q9, d8, d2[0] \n" - "vmlal.s16 q13, d9, d2[1] \n" - "vmlal.s16 q10, d8, d4[0] \n" - "vmlal.s16 q14, d9, d4[1] \n" - "vmlal.s16 q11, d8, d6[0] \n" - "vmlal.s16 q15, d9, d6[1] \n" - - "pld [%3, #128] \n" - "vld1.s16 {d8-d9}, [%3 :128]! \n" - - "vmlal.s16 q8, d10, d0[2] \n" - "vmlal.s16 q12, d11, d0[3] \n" - "vmlal.s16 q9, d10, d2[2] \n" - "vmlal.s16 q13, d11, d2[3] \n" - "vmlal.s16 q10, d10, d4[2] \n" - "vmlal.s16 q14, d11, d4[3] \n" - "vmlal.s16 q11, d10, d6[2] \n" - "vmlal.s16 q15, d11, d6[3] \n" - - "pld [%3, #128] \n" - "vld1.s16 {d10-d11}, [%3 :128]! 
\n" - - "vmlal.s16 q8, d8, d1[0] \n" - "vmlal.s16 q12, d9, d1[1] \n" - "vmlal.s16 q9, d8, d3[0] \n" - "vmlal.s16 q13, d9, d3[1] \n" - "vmlal.s16 q10, d8, d5[0] \n" - "vmlal.s16 q14, d9, d5[1] \n" - "vmlal.s16 q11, d8, d7[0] \n" - "vmlal.s16 q15, d9, d7[1] \n" - - "subs %0, %0, #1 \n" - - "vmlal.s16 q8, d10, d1[2] \n" - "vmlal.s16 q12, d11, d1[3] \n" - "vmlal.s16 q9, d10, d3[2] \n" - "vmlal.s16 q13, d11, d3[3] \n" - "vmlal.s16 q10, d10, d5[2] \n" - "vmlal.s16 q14, d11, d5[3] \n" - "vmlal.s16 q11, d10, d7[2] \n" - "vmlal.s16 q15, d11, d7[3] \n" - - "bne 0b \n" - - "vadd.s32 q8, q8, q12 \n" - "vadd.s32 q9, q9, q13 \n" - "vadd.s32 q10, q10, q14 \n" - "vadd.s32 q11, q11, q15 \n" - - "vstm %1!, {d16-d23} \n" - - : "=r"(nn), - "=r"(output0_tm), - "=r"(r0), - "=r"(k0) - : "0"(nn), - "1"(output0_tm), - "2"(r0), - "3"(k0) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); - - r0 += 16; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum2); - output0_tm += 8; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = 
vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - - vst1q_s32(output0_tm, _sum0); - output0_tm += 4; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_int8.h b/src/layer/arm/convolution_winograd_transform_int8.h deleted file mode 100644 index 4e27e8c6287..00000000000 --- a/src/layer/arm/convolution_winograd_transform_int8.h +++ /dev/null @@ -1,230 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
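For reference, a minimal scalar sketch of what the deleted vmlal_lane_s16 loops above compute per tile, assuming the packed layout they consume (8 transformed int16 input values per input channel per tile, weights stored lane-major for 4 packed output channels). The function name and signature are illustrative, not part of ncnn.

// Scalar reference for the single-tile case: four int32 sums, one per packed
// output channel, accumulated over all input channels and the 8 tile lanes.
static void winograd_tile_gemm_ref(const short* r0, // 8 values per input channel
                                   const short* k0, // 8 x 4 weights per input channel
                                   int* out,        // 4 packed output channels
                                   int inch)
{
    int sum[4] = {0, 0, 0, 0};
    for (int q = 0; q < inch; q++)
    {
        for (int e = 0; e < 8; e++)      // the 8 lanes of _val0
        {
            for (int p = 0; p < 4; p++)  // the 4 lanes of each accumulator
                sum[p] += (int)r0[e] * (int)k0[e * 4 + p];
        }
        r0 += 8;
        k0 += 32;
    }
    for (int p = 0; p < 4; p++)
        out[p] = sum[p];
}

The NEON version splits this accumulation across _sum0/_sum1 (and more accumulators for the wider tile counts) to hide multiply-accumulate latency; the final vaddq_s32 merges them back into the same four sums before the store.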
- -static void conv3x3s1_winograd43_transform_input_int8_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - - const int w_tiles = (w - 2) / 4; - const int h_tiles = (h - 2) / 4; - const int tiles = w_tiles * h_tiles; - - // const float itm[6][6] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4); - - for (int m = 0; m < 6; m++) - { - signed char r00 = r0[0]; - signed char r01 = r0[1]; - signed char r02 = r0[2]; - signed char r03 = r0[3]; - signed char r04 = r0[4]; - signed char r05 = r0[5]; - - short tmp0m = 4 * r00 - 5 * r02 + r04; - short tmp1m = -4 * (r01 + r02) + r04 + r03; - short tmp2m = 4 * (r01 - r02) + r04 - r03; - short tmp3m = -2 * (r01 - r03) + r04 - r02; - short tmp4m = 2 * (r01 - r03) + r04 - r02; - short tmp5m = 4 * r01 - 5 * r03 + r05; - - tmp[0][m] = tmp0m; - tmp[1][m] = tmp1m; - tmp[2][m] = tmp2m; - tmp[3][m] = tmp3m; - tmp[4][m] = tmp4m; - tmp[5][m] = tmp5m; - - r0 += w; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tiles + j); - short* r0_tm_1 = r0_tm_0 + tiles; - short* r0_tm_2 = r0_tm_0 + tiles * 2; - short* r0_tm_3 = r0_tm_0 + tiles * 3; - short* r0_tm_4 = r0_tm_0 + tiles * 4; - short* r0_tm_5 = r0_tm_0 + tiles * 5; - - for (int m = 0; m < 6; m++) - { - short tmp00 = tmp[m][0]; - short tmp01 = tmp[m][1]; - short tmp02 = tmp[m][2]; - short tmp03 = tmp[m][3]; - short tmp04 = tmp[m][4]; - short tmp05 = tmp[m][5]; - - short r0tm0 = 4 * tmp00 - 5 * tmp02 + tmp04; - short r0tm1 = -4 * (tmp01 + tmp02) + tmp04 + tmp03; - short r0tm2 = 4 * (tmp01 - tmp02) + tmp04 - tmp03; - short r0tm3 = -2 * (tmp01 - tmp03) + tmp04 - tmp02; - short r0tm4 = 2 * (tmp01 - tmp03) + tmp04 - tmp02; - short r0tm5 = 4 * tmp01 - 5 * tmp03 + tmp05; - - r0_tm_0[0] = r0tm0; - r0_tm_1[0] = r0tm1; - r0_tm_2[0] = r0tm2; - r0_tm_3[0] = r0tm3; - r0_tm_4[0] = r0tm4; - r0_tm_5[0] = r0tm5; - - r0_tm_0 += tiles * 6; - r0_tm_1 += tiles * 6; - r0_tm_2 += tiles * 6; - r0_tm_3 += tiles * 6; - r0_tm_4 += tiles * 6; - r0_tm_5 += tiles * 6; - } - } - } - } -} - -static void conv3x3s1_winograd43_transform_output_int8_neon(const Mat& top_blob_tm, Mat& top_blob, const Option& opt) -{ - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - - const int w_tiles = outw / 4; - const int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 
- r02) + (r03 - r04) * 8 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob.channel(p); - - int tmp[4][6]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tiles + j) * 1; - const int* output0_tm_1 = output0_tm_0 + tiles * 1; - const int* output0_tm_2 = output0_tm_0 + tiles * 2; - const int* output0_tm_3 = output0_tm_0 + tiles * 3; - const int* output0_tm_4 = output0_tm_0 + tiles * 4; - const int* output0_tm_5 = output0_tm_0 + tiles * 5; - - int* output0 = out0.row(i * 4) + j * 4; - - // TODO neon optimize - for (int m = 0; m < 5; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = output0_tm_0[0] + tmp02a + tmp02b; - tmp[1][m] = tmp13a + tmp13b * 2; - tmp[2][m] = tmp02a + tmp02b * 4; - tmp[3][m] = output0_tm_5[0] * 4 + tmp13a + tmp13b * 8; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - for (int m = 5; m < 6; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = (output0_tm_0[0] + tmp02a + tmp02b) * 4; - tmp[1][m] = (tmp13a + tmp13b * 2) * 4; - tmp[2][m] = (tmp02a + tmp02b * 4) * 4; - tmp[3][m] = (output0_tm_5[0] * 4 + tmp13a + tmp13b * 8) * 4; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - - for (int m = 0; m < 4; m++) - { - const int* tmp0 = tmp[m]; - - int tmp02a = tmp0[1] + tmp0[2]; - int tmp13a = tmp0[1] - tmp0[2]; - - int tmp02b = tmp0[3] + tmp0[4]; - int tmp13b = tmp0[3] - tmp0[4]; - - output0[0] = (tmp0[0] + tmp02a + tmp02b) / 576; - output0[1] = (tmp13a + tmp13b * 2) / 576; - output0[2] = (tmp02a + tmp02b * 4) / 576; - output0[3] = (tmp0[5] + tmp13a + tmp13b * 8) / 576; - - output0 += outw; - } - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_pack4_int8.h b/src/layer/arm/convolution_winograd_transform_pack4_int8.h deleted file mode 100644 index fff5f7d6650..00000000000 --- a/src/layer/arm/convolution_winograd_transform_pack4_int8.h +++ /dev/null @@ -1,178 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
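The two deleted functions above apply the same 1-D winograd43 transforms first along rows, then along columns. As a reading aid, here is the 1-D output transform from the otm comment in plain scalar form (names are illustrative only). The deleted code additionally folds an extra x4 onto the sixth transformed component during the row pass and divides the final result by 576 to undo the integer scaling of the transforms.

// One row/column of the 6-wide transformed tile in, one row/column of the
// 4-wide spatial tile out, exactly the four formulas in the otm comment.
static void winograd43_output_transform_1d(const int t[6], int out[4])
{
    int a02 = t[1] + t[2];
    int s13 = t[1] - t[2];
    int b02 = t[3] + t[4];
    int d13 = t[3] - t[4];

    out[0] = t[0] + a02 + b02;
    out[1] = s13 + d13 * 2;
    out[2] = a02 + b02 * 4;
    out[3] = t[5] + s13 + d13 * 8;
}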
- -static void conv3x3s1_winograd43_transform_output_pack4_int8_neon(const Mat& top_blob_tm, Mat& top_blob, const Option& opt) -{ - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - - const int w_tiles = outw / 4; - const int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob.channel(p); - - int tmp[4][6][4]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tiles + j) * 4; - const int* output0_tm_1 = output0_tm_0 + tiles * 4; - const int* output0_tm_2 = output0_tm_0 + tiles * 8; - const int* output0_tm_3 = output0_tm_0 + tiles * 12; - const int* output0_tm_4 = output0_tm_0 + tiles * 16; - const int* output0_tm_5 = output0_tm_0 + tiles * 20; - - int* output0 = out0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 5; m++) - { - int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); - int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); - int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); - int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); - int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); - int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); - - int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); - int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); - - int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); - int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); - int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); - - vst1q_s32(tmp[0][m], _tmp0m); - vst1q_s32(tmp[1][m], _tmp1m); - vst1q_s32(tmp[2][m], _tmp2m); - vst1q_s32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - for (int m = 5; m < 6; m++) - { - int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); - int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); - int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); - int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); - int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); - int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); - - int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); - int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); - - int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); - int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); - int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); - - _tmp0m = 
vmulq_s32(_tmp0m, _v4); - _tmp1m = vmulq_s32(_tmp1m, _v4); - _tmp2m = vmulq_s32(_tmp2m, _v4); - _tmp3m = vmulq_s32(_tmp3m, _v4); - - vst1q_s32(tmp[0][m], _tmp0m); - vst1q_s32(tmp[1][m], _tmp1m); - vst1q_s32(tmp[2][m], _tmp2m); - vst1q_s32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - int32x4_t _tmp00 = vld1q_s32(tmp[m][0]); - int32x4_t _tmp01 = vld1q_s32(tmp[m][1]); - int32x4_t _tmp02 = vld1q_s32(tmp[m][2]); - int32x4_t _tmp03 = vld1q_s32(tmp[m][3]); - int32x4_t _tmp04 = vld1q_s32(tmp[m][4]); - int32x4_t _tmp05 = vld1q_s32(tmp[m][5]); - - int32x4_t _tmp02a = vaddq_s32(_tmp01, _tmp02); - int32x4_t _tmp13a = vsubq_s32(_tmp01, _tmp02); - - int32x4_t _tmp02b = vaddq_s32(_tmp03, _tmp04); - int32x4_t _tmp13b = vsubq_s32(_tmp03, _tmp04); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _out00 = vaddq_s32(vaddq_s32(_tmp00, _tmp02a), _tmp02b); - int32x4_t _out01 = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _out02 = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _out03 = vmlaq_s32(vaddq_s32(_tmp05, _tmp13a), _tmp13b, _v8); - - // TODO use integer trick for division by 576 - float32x4_t _v576 = vdupq_n_f32(1.0 / 576); - _out00 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out00), _v576)); - _out01 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out01), _v576)); - _out02 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out02), _v576)); - _out03 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out03), _v576)); - - vst1q_s32(output0, _out00); - vst1q_s32(output0 + 4, _out01); - vst1q_s32(output0 + 8, _out02); - vst1q_s32(output0 + 12, _out03); - - output0 += outw * 4; - } - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_pack8_int8.h b/src/layer/arm/convolution_winograd_transform_pack8_int8.h deleted file mode 100644 index f0d8981ef77..00000000000 --- a/src/layer/arm/convolution_winograd_transform_pack8_int8.h +++ /dev/null @@ -1,131 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
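The pack4 tail above converts the int32 accumulators to float, multiplies by 1/576 and converts back, which is what the "TODO use integer trick for division by 576" refers to. In scalar form that is roughly the following; the multiply-and-shift variant is shown only as one conventional integer replacement (a sketch, not necessarily the approach ncnn later adopted), and its tie-breaking can differ slightly from exact rounding.

// Scalar equivalent of the deleted NEON tail: division by 576 via a float
// reciprocal, like vcvtq_f32_s32 / vmulq_f32 / vcvtq_s32_f32 above.
static inline int div576_via_float(int x)
{
    return (int)(x * (1.0f / 576)); // rounds toward zero, like vcvtq_s32_f32
}

// One possible integer alternative: multiply by a fixed-point reciprocal of
// 576 and shift.  477218588 ~= 2^38 / 576; adding 2^37 rounds to nearest
// (exact half-way ties may fall either way, which is acceptable here since
// the float path above is not exact either).
static inline int div576_via_fixed_point(int x)
{
    return (int)(((long long)x * 477218588 + (1LL << 37)) >> 38);
}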
- -static void conv3x3s1_winograd43_transform_input_pack8_int8_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - - const int w_tiles = (w - 2) / 4; - const int h_tiles = (h - 2) / 4; - const int tiles = w_tiles * h_tiles; - - // const float itm[6][6] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - int8x8_t _r00 = vld1_s8(r0); - int8x8_t _r01 = vld1_s8(r0 + 8); - int8x8_t _r02 = vld1_s8(r0 + 16); - int8x8_t _r03 = vld1_s8(r0 + 24); - int8x8_t _r04 = vld1_s8(r0 + 32); - int8x8_t _r05 = vld1_s8(r0 + 40); - - int8x8_t _v4s8 = vdup_n_s8(4); - int8x8_t _v5s8 = vdup_n_s8(5); - int16x8_t _v2 = vdupq_n_s16(2); - int16x8_t _v4 = vdupq_n_s16(4); - - int16x8_t _tmp0m = vsubq_s16(vaddw_s8(vmull_s8(_r00, _v4s8), _r04), vmull_s8(_r02, _v5s8)); - int16x8_t _tmp1m = vmlsq_s16(vaddl_s8(_r04, _r03), vaddl_s8(_r01, _r02), _v4); - int16x8_t _tmp2m = vmlaq_s16(vsubl_s8(_r04, _r03), vsubl_s8(_r01, _r02), _v4); - int16x8_t _tmp3m = vmlsq_s16(vsubl_s8(_r04, _r02), vsubl_s8(_r01, _r03), _v2); - int16x8_t _tmp4m = vmlaq_s16(vsubl_s8(_r04, _r02), vsubl_s8(_r01, _r03), _v2); - int16x8_t _tmp5m = vsubq_s16(vaddw_s8(vmull_s8(_r01, _v4s8), _r05), vmull_s8(_r03, _v5s8)); - - vst1q_s16(tmp[0][m], _tmp0m); - vst1q_s16(tmp[1][m], _tmp1m); - vst1q_s16(tmp[2][m], _tmp2m); - vst1q_s16(tmp[3][m], _tmp3m); - vst1q_s16(tmp[4][m], _tmp4m); - vst1q_s16(tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tiles + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - int16x8_t _tmp00 = vld1q_s16(tmp[m][0]); - int16x8_t _tmp01 = vld1q_s16(tmp[m][1]); - int16x8_t _tmp02 = vld1q_s16(tmp[m][2]); - int16x8_t _tmp03 = vld1q_s16(tmp[m][3]); - int16x8_t _tmp04 = vld1q_s16(tmp[m][4]); - int16x8_t _tmp05 = vld1q_s16(tmp[m][5]); - - int16x8_t _v2 = vdupq_n_s16(2); - int16x8_t _v4 = vdupq_n_s16(4); - int16x8_t _v5 = vdupq_n_s16(5); - - int16x8_t _r0tm0 = vmlsq_s16(vmlaq_s16(_tmp04, _tmp00, _v4), _tmp02, _v5); - int16x8_t _r0tm1 = vmlsq_s16(vaddq_s16(_tmp04, _tmp03), vaddq_s16(_tmp01, _tmp02), _v4); - int16x8_t _r0tm2 = vmlaq_s16(vsubq_s16(_tmp04, _tmp03), vsubq_s16(_tmp01, _tmp02), _v4); - int16x8_t _r0tm3 = vmlsq_s16(vsubq_s16(_tmp04, _tmp02), vsubq_s16(_tmp01, _tmp03), _v2); - int16x8_t _r0tm4 = vmlaq_s16(vsubq_s16(_tmp04, _tmp02), vsubq_s16(_tmp01, _tmp03), _v2); - int16x8_t _r0tm5 = vmlsq_s16(vmlaq_s16(_tmp05, _tmp01, _v4), _tmp03, _v5); - - vst1q_s16(r0_tm_0, _r0tm0); - 
vst1q_s16(r0_tm_1, _r0tm1); - vst1q_s16(r0_tm_2, _r0tm2); - vst1q_s16(r0_tm_3, _r0tm3); - vst1q_s16(r0_tm_4, _r0tm4); - vst1q_s16(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } -} diff --git a/src/layer/arm/gelu_arm.cpp b/src/layer/arm/gelu_arm.cpp index 3ae329a3a28..80d4efba0cb 100644 --- a/src/layer/arm/gelu_arm.cpp +++ b/src/layer/arm/gelu_arm.cpp @@ -14,8 +14,6 @@ #include "gelu_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/gelu_arm_asimdhp.cpp b/src/layer/arm/gelu_arm_asimdhp.cpp index 78514dbc042..ea8b159cfa8 100644 --- a/src/layer/arm/gelu_arm_asimdhp.cpp +++ b/src/layer/arm/gelu_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "gelu_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index aa927d26a58..70df351a555 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -14,8 +14,6 @@ #include "gru_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index f5e74b50284..ae657fc301b 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "gru_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/innerproduct_arm.h b/src/layer/arm/innerproduct_arm.h index 1eff44c7b1d..f1eee178f9c 100644 --- a/src/layer/arm/innerproduct_arm.h +++ b/src/layer/arm/innerproduct_arm.h @@ -16,8 +16,6 @@ #define LAYER_INNERPRODUCT_ARM_H #include "innerproduct.h" -#include -#include namespace ncnn { diff --git a/src/layer/arm/interp_arm.cpp b/src/layer/arm/interp_arm.cpp index 1ee97d57996..191499aa26b 100644 --- a/src/layer/arm/interp_arm.cpp +++ b/src/layer/arm/interp_arm.cpp @@ -14,8 +14,6 @@ #include "interp_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/interp_arm_asimdhp.cpp b/src/layer/arm/interp_arm_asimdhp.cpp index c9bf14b1077..286c74fe40c 100644 --- a/src/layer/arm/interp_arm_asimdhp.cpp +++ b/src/layer/arm/interp_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "interp_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/lrn_arm.cpp b/src/layer/arm/lrn_arm.cpp index fdc05c3f952..f763bfb2a2f 100644 --- a/src/layer/arm/lrn_arm.cpp +++ b/src/layer/arm/lrn_arm.cpp @@ -14,8 +14,6 @@ #include "lrn_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 79a0c97c917..04d7277547e 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -14,8 +14,6 @@ #include "lstm_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/lstm_arm_asimdhp.cpp b/src/layer/arm/lstm_arm_asimdhp.cpp index a394bad4c2e..8a3ee63e40a 100644 --- a/src/layer/arm/lstm_arm_asimdhp.cpp +++ b/src/layer/arm/lstm_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "lstm_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/mish_arm.cpp b/src/layer/arm/mish_arm.cpp index 54757380d0c..31c9f77df63 100644 --- a/src/layer/arm/mish_arm.cpp +++ b/src/layer/arm/mish_arm.cpp @@ -14,8 +14,6 @@ #include "mish_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/mish_arm_asimdhp.cpp 
b/src/layer/arm/mish_arm_asimdhp.cpp index e8db14d3e41..0e04883370e 100644 --- a/src/layer/arm/mish_arm_asimdhp.cpp +++ b/src/layer/arm/mish_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "mish_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/quantize_arm.cpp b/src/layer/arm/quantize_arm.cpp index aa2a61a3472..6e395a9bb76 100644 --- a/src/layer/arm/quantize_arm.cpp +++ b/src/layer/arm/quantize_arm.cpp @@ -15,8 +15,6 @@ #include "quantize_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/quantize_arm_asimdhp.cpp b/src/layer/arm/quantize_arm_asimdhp.cpp index d3a66271654..faccb907b41 100644 --- a/src/layer/arm/quantize_arm_asimdhp.cpp +++ b/src/layer/arm/quantize_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "quantize_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/requantize_arm.cpp b/src/layer/arm/requantize_arm.cpp index 4d4531e9438..32fdd961433 100644 --- a/src/layer/arm/requantize_arm.cpp +++ b/src/layer/arm/requantize_arm.cpp @@ -15,8 +15,6 @@ #include "requantize_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp index 87892d7ada2..19f439ea2d5 100644 --- a/src/layer/arm/rnn_arm.cpp +++ b/src/layer/arm/rnn_arm.cpp @@ -14,8 +14,6 @@ #include "rnn_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/rnn_arm_asimdhp.cpp b/src/layer/arm/rnn_arm_asimdhp.cpp index 79fb0b1db1e..c34b3e8bb48 100644 --- a/src/layer/arm/rnn_arm_asimdhp.cpp +++ b/src/layer/arm/rnn_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "rnn_arm.h" -#include - #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/layer/arm/selu_arm.cpp b/src/layer/arm/selu_arm.cpp index 219cd6d4fdf..afe360cd61b 100644 --- a/src/layer/arm/selu_arm.cpp +++ b/src/layer/arm/selu_arm.cpp @@ -26,8 +26,9 @@ int SELU_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { int w = bottom_top_blob.w; int h = bottom_top_blob.h; + int d = bottom_top_blob.d; int channels = bottom_top_blob.c; - int size = w * h; + int size = w * h * d; float alphaxlambda = alpha * lambda; #pragma omp parallel for num_threads(opt.num_threads) diff --git a/src/layer/arm/sigmoid_arm.cpp b/src/layer/arm/sigmoid_arm.cpp index fb79c4d56c1..af2b396dd5e 100644 --- a/src/layer/arm/sigmoid_arm.cpp +++ b/src/layer/arm/sigmoid_arm.cpp @@ -14,8 +14,6 @@ #include "sigmoid_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/sigmoid_arm_asimdhp.cpp b/src/layer/arm/sigmoid_arm_asimdhp.cpp index 3e5e6cd830d..65c32ee3e67 100644 --- a/src/layer/arm/sigmoid_arm_asimdhp.cpp +++ b/src/layer/arm/sigmoid_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "sigmoid_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/softmax_arm.cpp b/src/layer/arm/softmax_arm.cpp index 81907555469..48faaf91061 100644 --- a/src/layer/arm/softmax_arm.cpp +++ b/src/layer/arm/softmax_arm.cpp @@ -15,7 +15,6 @@ #include "softmax_arm.h" #include -#include #if __ARM_NEON #include diff --git a/src/layer/arm/softmax_arm_asimdhp.cpp b/src/layer/arm/softmax_arm_asimdhp.cpp index 2460a92f435..d8efaf4c3b9 100644 --- a/src/layer/arm/softmax_arm_asimdhp.cpp +++ b/src/layer/arm/softmax_arm_asimdhp.cpp @@ -15,7 +15,6 @@ #include "softmax_arm.h" #include -#include #if __ARM_NEON #include diff --git a/src/layer/arm/swish_arm.cpp b/src/layer/arm/swish_arm.cpp index 8b2ff9a01e5..d68e617276c 100644 --- 
a/src/layer/arm/swish_arm.cpp +++ b/src/layer/arm/swish_arm.cpp @@ -14,8 +14,6 @@ #include "swish_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/swish_arm_asimdhp.cpp b/src/layer/arm/swish_arm_asimdhp.cpp index 5a598f67501..4aee8a898c4 100644 --- a/src/layer/arm/swish_arm_asimdhp.cpp +++ b/src/layer/arm/swish_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "swish_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/tanh_arm.cpp b/src/layer/arm/tanh_arm.cpp index 0b9dd5c95e8..6e86d7ad300 100644 --- a/src/layer/arm/tanh_arm.cpp +++ b/src/layer/arm/tanh_arm.cpp @@ -14,8 +14,6 @@ #include "tanh_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/tanh_arm_asimdhp.cpp b/src/layer/arm/tanh_arm_asimdhp.cpp index e9297aa71a7..10f3303a1ce 100644 --- a/src/layer/arm/tanh_arm_asimdhp.cpp +++ b/src/layer/arm/tanh_arm_asimdhp.cpp @@ -14,8 +14,6 @@ #include "tanh_arm.h" -#include - #if __ARM_NEON #include #include "neon_mathfun.h" diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 5a054cc7c4d..e2dbd68c3a4 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -14,9 +14,8 @@ #include "unaryop_arm.h" -#include +// #include #include -#include #if __ARM_NEON #include diff --git a/src/layer/arm/unaryop_arm_asimdhp.cpp b/src/layer/arm/unaryop_arm_asimdhp.cpp index 02532db4114..ac64fc708f9 100644 --- a/src/layer/arm/unaryop_arm_asimdhp.cpp +++ b/src/layer/arm/unaryop_arm_asimdhp.cpp @@ -14,9 +14,8 @@ #include "unaryop_arm.h" -#include +// #include #include -#include #if __ARM_NEON #include diff --git a/src/layer/batchnorm.cpp b/src/layer/batchnorm.cpp index cf0f871e58f..b13e5ef2966 100644 --- a/src/layer/batchnorm.cpp +++ b/src/layer/batchnorm.cpp @@ -14,8 +14,6 @@ #include "batchnorm.h" -#include - namespace ncnn { BatchNorm::BatchNorm() diff --git a/src/layer/binaryop.cpp b/src/layer/binaryop.cpp index 0ffaf80e391..52d3d083b31 100644 --- a/src/layer/binaryop.cpp +++ b/src/layer/binaryop.cpp @@ -14,8 +14,6 @@ #include "binaryop.h" -#include - namespace ncnn { BinaryOp::BinaryOp() diff --git a/src/layer/bnll.cpp b/src/layer/bnll.cpp index 72c2ab16170..9341ebcfcec 100644 --- a/src/layer/bnll.cpp +++ b/src/layer/bnll.cpp @@ -14,8 +14,6 @@ #include "bnll.h" -#include - namespace ncnn { BNLL::BNLL() diff --git a/src/layer/celu.cpp b/src/layer/celu.cpp index 58782f877cb..8c17244c0eb 100644 --- a/src/layer/celu.cpp +++ b/src/layer/celu.cpp @@ -14,8 +14,6 @@ #include "celu.h" -#include - namespace ncnn { CELU::CELU() diff --git a/src/layer/detectionoutput.cpp b/src/layer/detectionoutput.cpp index 266beaca75a..f90b904789b 100644 --- a/src/layer/detectionoutput.cpp +++ b/src/layer/detectionoutput.cpp @@ -14,8 +14,6 @@ #include "detectionoutput.h" -#include - namespace ncnn { DetectionOutput::DetectionOutput() diff --git a/src/layer/dropout.cpp b/src/layer/dropout.cpp index f64f7ea3008..9e5ddaa17b5 100644 --- a/src/layer/dropout.cpp +++ b/src/layer/dropout.cpp @@ -14,8 +14,6 @@ #include "dropout.h" -#include - namespace ncnn { Dropout::Dropout() diff --git a/src/layer/elu.cpp b/src/layer/elu.cpp index 0a3574e58b6..e710d4f1cc5 100644 --- a/src/layer/elu.cpp +++ b/src/layer/elu.cpp @@ -14,8 +14,6 @@ #include "elu.h" -#include - namespace ncnn { ELU::ELU() @@ -35,8 +33,9 @@ int ELU::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { int w = bottom_top_blob.w; int h = bottom_top_blob.h; + int d = bottom_top_blob.d; int 
channels = bottom_top_blob.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) diff --git a/src/layer/erf.cpp b/src/layer/erf.cpp index c5f56e835f0..8b455919ab2 100644 --- a/src/layer/erf.cpp +++ b/src/layer/erf.cpp @@ -13,7 +13,6 @@ // specific language governing permissions and limitations under the License. #include "erf.h" -#include namespace ncnn { diff --git a/src/layer/exp.cpp b/src/layer/exp.cpp index ea8bf7dbda7..83644a7934d 100644 --- a/src/layer/exp.cpp +++ b/src/layer/exp.cpp @@ -14,8 +14,6 @@ #include "exp.h" -#include - namespace ncnn { Exp::Exp() diff --git a/src/layer/fused_activation.h b/src/layer/fused_activation.h index a331a6df5da..275fd9e2f9a 100644 --- a/src/layer/fused_activation.h +++ b/src/layer/fused_activation.h @@ -15,7 +15,6 @@ #ifndef FUSED_ACTIVATION_H #define FUSED_ACTIVATION_H -#include #include "mat.h" #include "layer_type.h" diff --git a/src/layer/gelu.cpp b/src/layer/gelu.cpp index 32b2b89954f..d1072653774 100644 --- a/src/layer/gelu.cpp +++ b/src/layer/gelu.cpp @@ -14,8 +14,6 @@ #include "gelu.h" -#include - namespace ncnn { GELU::GELU() diff --git a/src/layer/glu.cpp b/src/layer/glu.cpp index 9555b88c645..8f8e057e9a4 100644 --- a/src/layer/glu.cpp +++ b/src/layer/glu.cpp @@ -14,8 +14,6 @@ #include "glu.h" -#include - namespace ncnn { GLU::GLU() diff --git a/src/layer/gridsample.cpp b/src/layer/gridsample.cpp index 31405047ec1..abeec6fa5be 100644 --- a/src/layer/gridsample.cpp +++ b/src/layer/gridsample.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // coord compliance with the License. You may obtain a copy of the License at @@ -14,8 +14,6 @@ #include "gridsample.h" -#include - namespace ncnn { GridSample::GridSample() @@ -29,6 +27,7 @@ int GridSample::load_param(const ParamDict& pd) sample_type = pd.get(0, 1); padding_mode = pd.get(1, 1); align_corner = pd.get(2, 0); + permute_fusion = pd.get(3, 0); if (sample_type < 1 || sample_type > 3) { @@ -59,19 +58,19 @@ static float grid_sample_unormalize(int w, float coordx, int align_corner) return align_corner ? (coordx + 1) / 2.f * (w - 1) : ((coordx + 1) * w - 1) / 2.f; } -static float border_coord(int x, int border) +static float border_coord(float x, float border) { - return std::min(border, std::max(x, 0)); + return std::min(border, std::max(x, 0.0f)); } static float reflect_coord(float x, int high) { - x = abs(x); - x = high - abs(x - high); + x = fabs(x); + x = high - fabs(x - high); return x; } -static int compute_coord(int sx, int w, int padding_mode, int align_corner) +static float compute_coord(float sx, int w, int padding_mode, int align_corner) { if (padding_mode == 2) // border { @@ -85,7 +84,7 @@ static int compute_coord(int sx, int w, int padding_mode, int align_corner) } else { - sx = static_cast(reflect_coord(sx + 0.5f, w) - 0.5f); + sx = reflect_coord(sx + 0.5, w) - 0.5; sx = border_coord(sx, w - 1); } } @@ -110,7 +109,7 @@ static float get_value_bounded(const Mat& image, int x, int y) static float get_value_bounded(const Mat& image, int x, int y, int z) { - return in_bounds(image, x, y, z) ? image.channel(z).row(y)[x] : 0.f; + return in_bounds(image, x, y, z) ? 
image.depth(z).row(y)[x] : 0.f; } static float get_value_bounded(const Mat& image, int x, int y, int padding_mode, int align_corner) @@ -121,15 +120,6 @@ static float get_value_bounded(const Mat& image, int x, int y, int padding_mode, return get_value_bounded(image, x, y); } -static float get_value_bounded(const Mat& image, int x, int y, int z, int padding_mode, int align_corner) -{ - x = compute_coord(x, image.w, padding_mode, align_corner); - y = compute_coord(y, image.h, padding_mode, align_corner); - z = compute_coord(z, image.c, padding_mode, align_corner); - - return get_value_bounded(image, x, y, z); -} - static inline void interpolate_cubic(float fx, float* coeffs) { const float A = -0.75f; @@ -160,45 +150,102 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& if (dims == 3) { - int outw = grid.h; - int outh = grid.c; + int outw = permute_fusion == 0 ? grid.h : grid.w; + int outh = permute_fusion == 0 ? grid.c : grid.h; top_blob.create(outw, outh, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) + + Mat offset_blob; + offset_blob.create(outw, outh, grid.c, elemsize, opt.workspace_allocator); + + if (top_blob.empty() || offset_blob.empty()) return -100; - if (sample_type == 1) // bilinear + //pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly + if (permute_fusion == 0) + { + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + + for (int y = 0; y < outh; y++) + { + const float* gridptr = grid.channel(y); + for (int x = 0; x < outw; x++) + { + float sample_x = gridptr[0]; + float sample_y = gridptr[1]; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + + gridptr += 2; + offsetptr_x++; + offsetptr_y++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + + gridptr_x++; + gridptr_y++; + offsetptr_x++; + offsetptr_y++; + } + } + } + + if (sample_type == Interpolation_BILINEAR) // bilinear { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; // bilinear interpolate float v; { - int x0 = (int)floor(sample_x); - int y0 = (int)floor(sample_y); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + int x0 = floor(sample_x); + int y0 = floor(sample_y); int x1 = x0 + 1; int y1 = y0 + 1; - float v00 = 
get_value_bounded(image, x0, y0, padding_mode, align_corner); - float v01 = get_value_bounded(image, x1, y0, padding_mode, align_corner); - float v10 = get_value_bounded(image, x0, y1, padding_mode, align_corner); - float v11 = get_value_bounded(image, x1, y1, padding_mode, align_corner); + float v00 = get_value_bounded(image, x0, y0); + float v01 = get_value_bounded(image, x1, y0); + float v10 = get_value_bounded(image, x0, y1); + float v11 = get_value_bounded(image, x1, y1); float alpha = sample_x - x0; float beta = sample_y - y0; @@ -212,63 +259,61 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } } - else if (sample_type == 2) // nearest + else if (sample_type == Interpolation_NEAREST) // nearest { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); - int x0 = static_cast(round(sample_x)); - int y0 = static_cast(round(sample_y)); + int x0 = static_cast(floor(sample_x + 0.5f)); + int y0 = static_cast(floor(sample_y + 0.5f)); - float v = get_value_bounded(image, x0, y0, padding_mode, align_corner); + float v = get_value_bounded(image, x0, y0); outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } } - else if (sample_type == 3) // bicubic + else if (sample_type == Interpolation_BICUBIC) // bicubic { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; // bicubic interpolate float v; @@ -315,7 +360,8 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } @@ -324,37 +370,120 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& if (dims == 4) { - int outw = grid.h; - int outh = grid.d; - int outd = grid.c; + int outw = permute_fusion == 0 ? grid.h : grid.w; + int outh = permute_fusion == 0 ? grid.d : grid.h; + int outd = permute_fusion == 0 ? 
grid.c : grid.d; top_blob.create(outw, outh, outd, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) + + Mat offset_blob; + offset_blob.create(outw, outh, outd, grid.c, elemsize, opt.workspace_allocator); + + if (top_blob.empty() || offset_blob.empty()) return -100; - if (sample_type == 1) // bilinear + //pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly + if (permute_fusion == 0) + { + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + float* offsetptr_z = offset_blob.channel(2); + + for (int z = 0; z < outd; z++) + { + const float* gridptr = grid.channel(z); + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = gridptr[0]; + float sample_y = gridptr[1]; + float sample_z = gridptr[2]; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + + sample_z = grid_sample_unormalize(d, sample_z, align_corner); + sample_z = compute_coord(sample_z, d, padding_mode, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + *offsetptr_z = sample_z; + + gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; + } + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + float* offsetptr_z = offset_blob.channel(2); + + for (int z = 0; z < outd; z++) + { + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + + sample_z = grid_sample_unormalize(d, sample_z, align_corner); + sample_z = compute_coord(sample_z, d, padding_mode, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + *offsetptr_z = sample_z; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; + } + } + } + } + + if (sample_type == Interpolation_BILINEAR) // bilinear { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); + const float* offsetptr_z = offset_blob.channel(2); for (int z = 0; z < outd; z++) { - const float* gridptr = grid.channel(z); - for (int y = 0; y < outh; y++) { for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - float sample_z = gridptr[2]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); - sample_z = grid_sample_unormalize(d, sample_z, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + float sample_z = *offsetptr_z; // bilinear interpolate float v; @@ -366,14 +495,14 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& int y1 = y0 + 1; int z1 = 
z0 + 1; - float v000 = get_value_bounded(image, x0, y0, z0, padding_mode, align_corner); - float v001 = get_value_bounded(image, x1, y0, z0, padding_mode, align_corner); - float v010 = get_value_bounded(image, x0, y1, z0, padding_mode, align_corner); - float v011 = get_value_bounded(image, x1, y1, z0, padding_mode, align_corner); - float v100 = get_value_bounded(image, x0, y0, z1, padding_mode, align_corner); - float v101 = get_value_bounded(image, x1, y0, z1, padding_mode, align_corner); - float v110 = get_value_bounded(image, x0, y1, z1, padding_mode, align_corner); - float v111 = get_value_bounded(image, x1, y1, z1, padding_mode, align_corner); + float v000 = get_value_bounded(image, x0, y0, z0); + float v001 = get_value_bounded(image, x1, y0, z0); + float v010 = get_value_bounded(image, x0, y1, z0); + float v011 = get_value_bounded(image, x1, y1, z0); + float v100 = get_value_bounded(image, x0, y0, z1); + float v101 = get_value_bounded(image, x1, y0, z1); + float v110 = get_value_bounded(image, x0, y1, z1); + float v111 = get_value_bounded(image, x1, y1, z1); float alpha = sample_x - x0; float beta = sample_y - y0; @@ -393,46 +522,47 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; } } } } } - else if (sample_type == 2) // nearest + else if (sample_type == Interpolation_NEAREST) // nearest { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); + const float* offsetptr_z = offset_blob.channel(2); for (int z = 0; z < outd; z++) { - const float* gridptr = grid.channel(z); - for (int y = 0; y < outh; y++) { for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - float sample_z = gridptr[2]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); - sample_z = grid_sample_unormalize(d, sample_z, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + float sample_z = *offsetptr_z; - int x0 = static_cast(round(sample_x)); - int y0 = static_cast(round(sample_y)); - int z0 = static_cast(round(sample_z)); + int x0 = static_cast(floor(sample_x + 0.5f)); + int y0 = static_cast(floor(sample_y + 0.5f)); + int z0 = static_cast(floor(sample_z + 0.5f)); - float v = get_value_bounded(image, x0, y0, z0, padding_mode, align_corner); + float v = get_value_bounded(image, x0, y0, z0); outptr[0] = v; outptr += 1; - gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; } } } diff --git a/src/layer/gridsample.h b/src/layer/gridsample.h index 0ea540eb4ba..f6e17c9d2f4 100644 --- a/src/layer/gridsample.h +++ b/src/layer/gridsample.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. 
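For readers following the GridSample changes, the coordinate handling that the new offset pre-computation relies on can be summarized in a few scalar helpers. This is a compact restatement of grid_sample_unormalize / compute_coord above; the helper names are illustrative only.

#include <algorithm>
#include <cmath>

// Map a normalized grid coordinate in [-1, 1] to pixel space.
static float unnormalize(int w, float coord, int align_corner)
{
    return align_corner ? (coord + 1) / 2.f * (w - 1)
                        : ((coord + 1) * w - 1) / 2.f;
}

// Single reflection fold, as in reflect_coord above.
static float reflect(float x, float high)
{
    x = std::fabs(x);
    return high - std::fabs(x - high);
}

// Fold an unnormalized coordinate back into [0, w-1] according to the
// padding mode; zeros padding leaves it untouched and lets the bounds
// check in get_value_bounded return 0 instead.
static float fold_coord(float sx, int w, int padding_mode, int align_corner)
{
    if (padding_mode == 2) // border
    {
        sx = std::min((float)(w - 1), std::max(sx, 0.f));
    }
    else if (padding_mode == 3) // reflection
    {
        if (align_corner)
        {
            sx = reflect(sx, (float)(w - 1));
        }
        else
        {
            sx = reflect(sx + 0.5f, (float)w) - 0.5f;
            sx = std::min((float)(w - 1), std::max(sx, 0.f));
        }
    }
    return sx;
}

Precomputing these per output position into offset_blob is what lets the per-channel sampling loops above avoid re-reading and re-unpacking the grid for every channel.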
You may obtain a copy of the License at @@ -28,11 +28,27 @@ class GridSample : public Layer virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + enum InterpolationMode // 1=bilinear 2=nearest 3=bicubic + { + Interpolation_BILINEAR = 1, + Interpolation_NEAREST = 2, + Interpolation_BICUBIC = 3 + }; + + enum PaddingMode // 1=zeros 2=border 3=reflection + { + Padding_ZEROS = 1, + Padding_BORDER = 2, + Padding_REFLECTION = 3 + }; + public: // param int sample_type; // 1=bilinear 2=nearest 3=bicubic int padding_mode; // 1=zeros 2=border 3=reflection int align_corner; + + int permute_fusion; }; } // namespace ncnn diff --git a/src/layer/groupnorm.cpp b/src/layer/groupnorm.cpp index f07be96cb54..7d28024d5ab 100644 --- a/src/layer/groupnorm.cpp +++ b/src/layer/groupnorm.cpp @@ -14,8 +14,6 @@ #include "groupnorm.h" -#include - namespace ncnn { GroupNorm::GroupNorm() diff --git a/src/layer/gru.cpp b/src/layer/gru.cpp index 1f7ddaef4ac..b1ef2e0da45 100644 --- a/src/layer/gru.cpp +++ b/src/layer/gru.cpp @@ -14,8 +14,6 @@ #include "gru.h" -#include - namespace ncnn { GRU::GRU() diff --git a/src/layer/instancenorm.cpp b/src/layer/instancenorm.cpp index 259fd7b26e5..27dba6c2a6b 100644 --- a/src/layer/instancenorm.cpp +++ b/src/layer/instancenorm.cpp @@ -14,8 +14,6 @@ #include "instancenorm.h" -#include - namespace ncnn { InstanceNorm::InstanceNorm() diff --git a/src/layer/layernorm.cpp b/src/layer/layernorm.cpp index d1361dec644..a4ff036fb15 100644 --- a/src/layer/layernorm.cpp +++ b/src/layer/layernorm.cpp @@ -14,8 +14,6 @@ #include "layernorm.h" -#include - namespace ncnn { LayerNorm::LayerNorm() diff --git a/src/layer/log.cpp b/src/layer/log.cpp index 135cc4ebb38..422ebbb2207 100644 --- a/src/layer/log.cpp +++ b/src/layer/log.cpp @@ -14,8 +14,6 @@ #include "log.h" -#include - namespace ncnn { Log::Log() diff --git a/src/layer/loongarch/binaryop_loongarch.cpp b/src/layer/loongarch/binaryop_loongarch.cpp index 0250226dc60..33916d966aa 100644 --- a/src/layer/loongarch/binaryop_loongarch.cpp +++ b/src/layer/loongarch/binaryop_loongarch.cpp @@ -14,8 +14,6 @@ #include "binaryop_loongarch.h" -#include - #if __loongarch_sx #include #include "lsx_mathfun.h" diff --git a/src/layer/loongarch/interp_loongarch.cpp b/src/layer/loongarch/interp_loongarch.cpp index 94d25cf005e..7c47c108859 100644 --- a/src/layer/loongarch/interp_loongarch.cpp +++ b/src/layer/loongarch/interp_loongarch.cpp @@ -14,8 +14,6 @@ #include "interp_loongarch.h" -#include - #if __loongarch_sx #include #endif // __loongarch_sx diff --git a/src/layer/loongarch/loongarch_usability.h b/src/layer/loongarch/loongarch_usability.h index d3ae5dec279..0cd82e8fb45 100644 --- a/src/layer/loongarch/loongarch_usability.h +++ b/src/layer/loongarch/loongarch_usability.h @@ -19,7 +19,6 @@ #include #endif // __loongarch_sx -#include #include namespace ncnn { diff --git a/src/layer/loongarch/mish_loongarch.cpp b/src/layer/loongarch/mish_loongarch.cpp index 8558e2f8cb0..90e5ffe5484 100644 --- a/src/layer/loongarch/mish_loongarch.cpp +++ b/src/layer/loongarch/mish_loongarch.cpp @@ -19,8 +19,6 @@ #include "lsx_mathfun.h" #endif // __loongarch_sx -#include - namespace ncnn { Mish_loongarch::Mish_loongarch() diff --git a/src/layer/loongarch/quantize_loongarch.cpp b/src/layer/loongarch/quantize_loongarch.cpp index 657ff2d06bf..a0dd618771d 100644 --- a/src/layer/loongarch/quantize_loongarch.cpp +++ b/src/layer/loongarch/quantize_loongarch.cpp @@ -14,8 +14,6 @@ #include "quantize_loongarch.h" -#include - #if 
__loongarch_sx #include #endif // __loongarch_sx diff --git a/src/layer/loongarch/requantize_loongarch.cpp b/src/layer/loongarch/requantize_loongarch.cpp index 556d20de4f6..3399ac096b6 100644 --- a/src/layer/loongarch/requantize_loongarch.cpp +++ b/src/layer/loongarch/requantize_loongarch.cpp @@ -14,8 +14,6 @@ #include "requantize_loongarch.h" -#include - #if __loongarch_sx #include #endif // __loongarch_sx diff --git a/src/layer/loongarch/sigmoid_loongarch.cpp b/src/layer/loongarch/sigmoid_loongarch.cpp index 6d112804f26..c6f83c24708 100644 --- a/src/layer/loongarch/sigmoid_loongarch.cpp +++ b/src/layer/loongarch/sigmoid_loongarch.cpp @@ -21,8 +21,6 @@ #include "loongarch_usability.h" -#include - namespace ncnn { Sigmoid_loongarch::Sigmoid_loongarch() diff --git a/src/layer/loongarch/softmax_loongarch.cpp b/src/layer/loongarch/softmax_loongarch.cpp index 88b49559754..513f9a5e9ca 100644 --- a/src/layer/loongarch/softmax_loongarch.cpp +++ b/src/layer/loongarch/softmax_loongarch.cpp @@ -15,7 +15,6 @@ #include "softmax_loongarch.h" #include -#include #if __loongarch_sx #include diff --git a/src/layer/loongarch/swish_loongarch.cpp b/src/layer/loongarch/swish_loongarch.cpp index 9c9005de6fc..7e80339c937 100644 --- a/src/layer/loongarch/swish_loongarch.cpp +++ b/src/layer/loongarch/swish_loongarch.cpp @@ -19,8 +19,6 @@ #include "lsx_mathfun.h" #endif // __loongarch_sx -#include - namespace ncnn { Swish_loongarch::Swish_loongarch() diff --git a/src/layer/loongarch/tanh_loongarch.cpp b/src/layer/loongarch/tanh_loongarch.cpp index 13227fa71e3..b592c3f57b2 100644 --- a/src/layer/loongarch/tanh_loongarch.cpp +++ b/src/layer/loongarch/tanh_loongarch.cpp @@ -19,8 +19,6 @@ #include "lsx_mathfun.h" #endif // __loongarch_sx -#include - namespace ncnn { TanH_loongarch::TanH_loongarch() diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index 4d4818cb5af..95a4e9984b6 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -14,9 +14,8 @@ #include "unaryop_loongarch.h" -#include +// #include #include -#include #if __loongarch_sx #include diff --git a/src/layer/lrn.cpp b/src/layer/lrn.cpp index aaa8855135a..c18f1def9fb 100644 --- a/src/layer/lrn.cpp +++ b/src/layer/lrn.cpp @@ -14,8 +14,6 @@ #include "lrn.h" -#include - namespace ncnn { LRN::LRN() diff --git a/src/layer/lstm.cpp b/src/layer/lstm.cpp index f2aa19f25ab..c761a98d4dd 100644 --- a/src/layer/lstm.cpp +++ b/src/layer/lstm.cpp @@ -14,8 +14,6 @@ #include "lstm.h" -#include - namespace ncnn { LSTM::LSTM() diff --git a/src/layer/memorydata.cpp b/src/layer/memorydata.cpp index 6cd314d76b9..02a0e0be078 100644 --- a/src/layer/memorydata.cpp +++ b/src/layer/memorydata.cpp @@ -28,6 +28,7 @@ int MemoryData::load_param(const ParamDict& pd) h = pd.get(1, 0); d = pd.get(11, 0); c = pd.get(2, 0); + load_type = pd.get(21, 1); return 0; } @@ -36,19 +37,19 @@ int MemoryData::load_model(const ModelBin& mb) { if (d != 0) { - data = mb.load(w, h, d, c, 1); + data = mb.load(w, h, d, c, load_type); } else if (c != 0) { - data = mb.load(w, h, c, 1); + data = mb.load(w, h, c, load_type); } else if (h != 0) { - data = mb.load(w, h, 1); + data = mb.load(w, h, load_type); } else if (w != 0) { - data = mb.load(w, 1); + data = mb.load(w, load_type); } else // 0 0 0 { diff --git a/src/layer/memorydata.h b/src/layer/memorydata.h index 4b2c697912f..d5175ad0dd8 100644 --- a/src/layer/memorydata.h +++ b/src/layer/memorydata.h @@ -35,6 +35,7 @@ class MemoryData : public Layer int h; int d; 
int c; + int load_type; Mat data; }; diff --git a/src/layer/mips/binaryop_mips.cpp b/src/layer/mips/binaryop_mips.cpp index ab8bfe86ac3..188a0860508 100644 --- a/src/layer/mips/binaryop_mips.cpp +++ b/src/layer/mips/binaryop_mips.cpp @@ -14,8 +14,6 @@ #include "binaryop_mips.h" -#include - #if __mips_msa #include #include "msa_mathfun.h" diff --git a/src/layer/mips/interp_mips.cpp b/src/layer/mips/interp_mips.cpp index 7d77e9b9dbf..2cc3202e915 100644 --- a/src/layer/mips/interp_mips.cpp +++ b/src/layer/mips/interp_mips.cpp @@ -14,8 +14,6 @@ #include "interp_mips.h" -#include - #if __mips_msa #include #endif // __mips_msa diff --git a/src/layer/mips/mips_usability.h b/src/layer/mips/mips_usability.h index 4aee94e75a9..662320ee747 100644 --- a/src/layer/mips/mips_usability.h +++ b/src/layer/mips/mips_usability.h @@ -20,7 +20,6 @@ #include #endif // __mips_msa -#include #include namespace ncnn { diff --git a/src/layer/mips/mish_mips.cpp b/src/layer/mips/mish_mips.cpp index 3dc81450914..32f8a6e173c 100644 --- a/src/layer/mips/mish_mips.cpp +++ b/src/layer/mips/mish_mips.cpp @@ -19,8 +19,6 @@ #include "msa_mathfun.h" #endif // __mips_msa -#include - namespace ncnn { Mish_mips::Mish_mips() diff --git a/src/layer/mips/quantize_mips.cpp b/src/layer/mips/quantize_mips.cpp index a4b61601661..963d0908ce4 100644 --- a/src/layer/mips/quantize_mips.cpp +++ b/src/layer/mips/quantize_mips.cpp @@ -14,8 +14,6 @@ #include "quantize_mips.h" -#include - #if __mips_msa #include #endif // __mips_msa diff --git a/src/layer/mips/requantize_mips.cpp b/src/layer/mips/requantize_mips.cpp index 095f42084c9..44e55f89477 100644 --- a/src/layer/mips/requantize_mips.cpp +++ b/src/layer/mips/requantize_mips.cpp @@ -14,8 +14,6 @@ #include "requantize_mips.h" -#include - #if __mips_msa #include #endif // __mips_msa diff --git a/src/layer/mips/sigmoid_mips.cpp b/src/layer/mips/sigmoid_mips.cpp index af44f811364..b7f83f37bb2 100644 --- a/src/layer/mips/sigmoid_mips.cpp +++ b/src/layer/mips/sigmoid_mips.cpp @@ -21,8 +21,6 @@ #include "mips_usability.h" -#include - namespace ncnn { Sigmoid_mips::Sigmoid_mips() diff --git a/src/layer/mips/softmax_mips.cpp b/src/layer/mips/softmax_mips.cpp index ae35782da9f..f00b2849670 100644 --- a/src/layer/mips/softmax_mips.cpp +++ b/src/layer/mips/softmax_mips.cpp @@ -15,7 +15,6 @@ #include "softmax_mips.h" #include -#include #if __mips_msa #include diff --git a/src/layer/mips/swish_mips.cpp b/src/layer/mips/swish_mips.cpp index d3a7d032b55..6c6a368301d 100644 --- a/src/layer/mips/swish_mips.cpp +++ b/src/layer/mips/swish_mips.cpp @@ -19,8 +19,6 @@ #include "msa_mathfun.h" #endif // __mips_msa -#include - namespace ncnn { Swish_mips::Swish_mips() diff --git a/src/layer/mips/tanh_mips.cpp b/src/layer/mips/tanh_mips.cpp index c2197fb75d9..4546a98de63 100644 --- a/src/layer/mips/tanh_mips.cpp +++ b/src/layer/mips/tanh_mips.cpp @@ -19,8 +19,6 @@ #include "msa_mathfun.h" #endif // __mips_msa -#include - namespace ncnn { TanH_mips::TanH_mips() diff --git a/src/layer/mips/unaryop_mips.cpp b/src/layer/mips/unaryop_mips.cpp index b923535a2d8..cb3c115cd00 100644 --- a/src/layer/mips/unaryop_mips.cpp +++ b/src/layer/mips/unaryop_mips.cpp @@ -14,9 +14,8 @@ #include "unaryop_mips.h" -#include +// #include #include -#include #if __mips_msa #include diff --git a/src/layer/mish.cpp b/src/layer/mish.cpp index 8b2f16500c7..f27d112f445 100644 --- a/src/layer/mish.cpp +++ b/src/layer/mish.cpp @@ -14,8 +14,6 @@ #include "mish.h" -#include - namespace ncnn { Mish::Mish() diff --git a/src/layer/mvn.cpp 
b/src/layer/mvn.cpp index 773ace23c50..713fb1b4195 100644 --- a/src/layer/mvn.cpp +++ b/src/layer/mvn.cpp @@ -14,8 +14,6 @@ #include "mvn.h" -#include - namespace ncnn { MVN::MVN() diff --git a/src/layer/normalize.cpp b/src/layer/normalize.cpp index 2aa6109b187..a86851117c9 100644 --- a/src/layer/normalize.cpp +++ b/src/layer/normalize.cpp @@ -14,8 +14,6 @@ #include "normalize.h" -#include - namespace ncnn { Normalize::Normalize() diff --git a/src/layer/power.cpp b/src/layer/power.cpp index a25d23bfb63..8e4ef25852b 100644 --- a/src/layer/power.cpp +++ b/src/layer/power.cpp @@ -14,8 +14,6 @@ #include "power.h" -#include - namespace ncnn { Power::Power() diff --git a/src/layer/priorbox.cpp b/src/layer/priorbox.cpp index 82249a55f63..6e54ba0162d 100644 --- a/src/layer/priorbox.cpp +++ b/src/layer/priorbox.cpp @@ -14,8 +14,6 @@ #include "priorbox.h" -#include - namespace ncnn { PriorBox::PriorBox() diff --git a/src/layer/proposal.cpp b/src/layer/proposal.cpp index 908b60692da..a7dce35f6ee 100644 --- a/src/layer/proposal.cpp +++ b/src/layer/proposal.cpp @@ -14,8 +14,6 @@ #include "proposal.h" -#include - namespace ncnn { Proposal::Proposal() diff --git a/src/layer/psroipooling.cpp b/src/layer/psroipooling.cpp index ebe2ad800c6..c576e31161c 100644 --- a/src/layer/psroipooling.cpp +++ b/src/layer/psroipooling.cpp @@ -14,8 +14,6 @@ #include "psroipooling.h" -#include - namespace ncnn { PSROIPooling::PSROIPooling() diff --git a/src/layer/quantize.cpp b/src/layer/quantize.cpp index 54bfb836f52..a53cebdd9a0 100644 --- a/src/layer/quantize.cpp +++ b/src/layer/quantize.cpp @@ -14,8 +14,6 @@ #include "quantize.h" -#include - namespace ncnn { Quantize::Quantize() diff --git a/src/layer/reduction.cpp b/src/layer/reduction.cpp index f7c9013b8f4..4d4f7fb578b 100644 --- a/src/layer/reduction.cpp +++ b/src/layer/reduction.cpp @@ -16,7 +16,6 @@ #include #include -#include namespace ncnn { diff --git a/src/layer/requantize.cpp b/src/layer/requantize.cpp index 0bcbbff879f..e11fbc6b272 100644 --- a/src/layer/requantize.cpp +++ b/src/layer/requantize.cpp @@ -15,8 +15,6 @@ #include "requantize.h" -#include - namespace ncnn { static inline signed char float2int8(float v) diff --git a/src/layer/riscv/binaryop_riscv.cpp b/src/layer/riscv/binaryop_riscv.cpp index c3d4258dd5e..da4593197f4 100644 --- a/src/layer/riscv/binaryop_riscv.cpp +++ b/src/layer/riscv/binaryop_riscv.cpp @@ -17,8 +17,6 @@ #include "binaryop_riscv.h" -#include - #if __riscv_vector #include #include "rvv_mathfun.h" diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp new file mode 100644 index 00000000000..ec5a5cdac41 --- /dev/null +++ b/src/layer/riscv/gemm_riscv.cpp @@ -0,0 +1,4282 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
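The new gemm_riscv.cpp pins every vector operation to a fixed vl of 4 float32 lanes (vsetvlmax_e32m1() capped at 4), so the pack_A/pack_B helpers and the micro-kernels that follow can address their packed tiles with constant offsets regardless of the hardware VLEN. The fragment below is a minimal standalone sketch of that accumulation pattern, written with the same pre-1.0 RVV intrinsic names used throughout this file; the helper name, tile size, and buffer layout are illustrative assumptions, not code taken from the patch.

#if __riscv_vector
#include <riscv_vector.h>

// Hypothetical sketch: C4[0..3] += A_packed[k*4 .. k*4+3] * B_packed[k] over K,
// with vl pinned to 4 lanes the same way Gemm_riscv() caps it.
static void axpy4_tile_sketch(const float* A_packed, const float* B_packed,
                              float* C4, int K)
{
    size_t vl = vsetvlmax_e32m1();
    vl = vl >= 4 ? 4 : vl; // cap at 4 float lanes even on wider VLEN

    vfloat32m1_t _sum = vle32_v_f32m1(C4, vl); // running 4-lane accumulator
    for (int k = 0; k < K; k++)
    {
        vfloat32m1_t _a = vle32_v_f32m1(A_packed + k * 4, vl); // 4 packed rows of A
        _sum = vfmadd_vf_f32m1(_a, B_packed[k], _sum, vl);     // _sum = _a * b + _sum
    }
    vse32_v_f32m1(C4, _sum, vl);
}
#endif // __riscv_vector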
+ +#include "gemm_riscv.h" + +#if __riscv_vector +#include +#endif // __riscv_vector + +#include "riscv_usability.h" + +#include "cpu.h" + +namespace ncnn { + +Gemm_riscv::Gemm_riscv() +{ +#if __riscv_vector + support_packing = true; +#endif // __riscv_vector + one_blob_only = false; + support_inplace = false; + + nT = 0; +#if __riscv_vector + // When processing float data, + // even if the current hardware provides vector registers of more than 128 bits, + // vl=4 is still used, even though this will waste the width of the vector register. + vl = vsetvlmax_e32m1(); + vl = vl >= 4 ? 4 : vl; +#else + vl = 0; +#endif // __riscv_vector +} + +static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + float* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; + const float* p1 = (const float*)A + (i + ii + 4) * A_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; + const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; + const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; + const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vfloat32m1_t _r0l = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r0h = vle32_v_f32m1(p0 + 4, vl); + vfloat32m1_t _r1l = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r1h = vle32_v_f32m1(p1 + 4, vl); + vfloat32m1_t _r2l = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r2h = vle32_v_f32m1(p2 + 4, vl); + vfloat32m1_t _r3l = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r3h = vle32_v_f32m1(p3 + 4, vl); + vfloat32m1_t _r4l = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r4h = vle32_v_f32m1(p4 + 4, vl); + vfloat32m1_t _r5l = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r5h = vle32_v_f32m1(p5 + 4, vl); + vfloat32m1_t _r6l = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r6h = vle32_v_f32m1(p6 + 4, vl); + vfloat32m1_t _r7l = vle32_v_f32m1(p7, vl); + vfloat32m1_t _r7h = vle32_v_f32m1(p7 + 4, vl); + transpose8x8_ps(_r0l, _r0h, _r1l, _r1h, _r2l, _r2h, _r3l, _r3h, _r4l, _r4h, _r5l, _r5h, _r6l, _r6h, _r7l, _r7h, vl); + vse32_v_f32m1(pp, _r0l, vl); + vse32_v_f32m1(pp + 4, _r0h, vl); + vse32_v_f32m1(pp + 8, _r1l, vl); + vse32_v_f32m1(pp + 12, _r1h, vl); + vse32_v_f32m1(pp + 8 * 2, _r2l, vl); + vse32_v_f32m1(pp + 8 * 2 + 4, _r2h, vl); + vse32_v_f32m1(pp + 8 * 3, _r3l, vl); + vse32_v_f32m1(pp + 8 * 3 + 4, _r3h, vl); + vse32_v_f32m1(pp + 8 * 4, _r4l, vl); + vse32_v_f32m1(pp + 8 * 4 + 4, _r4h, vl); + vse32_v_f32m1(pp + 8 * 5, _r5l, vl); + vse32_v_f32m1(pp + 8 * 5 + 4, _r5h, vl); + vse32_v_f32m1(pp + 8 * 6, _r6l, vl); + vse32_v_f32m1(pp + 8 * 6 + 4, _r6h, vl); + vse32_v_f32m1(pp + 8 * 7, _r7l, vl); + vse32_v_f32m1(pp + 8 * 7 + 4, _r7h, vl); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + 
pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t v2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t v3 = vle32_v_f32m1(p3, vl); + store_float_v4(v0, v1, v2, v3, pp, vl); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + // if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + } + for (; ii < max_ii; ii += 1) + { + // if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + float* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r1, vl); + vse32_v_f32m1(pp + 4 * 2, _r2, vl); + vse32_v_f32m1(pp + 4 * 3, _r3, vl); + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +{ + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + float* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; + const float* p2 = (const float*)B + (j + jj + 8) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + vse32_v_f32m1(pp + 8, vle32_v_f32m1(p2, vl), vl); + pp += 12; + p0 += 4; + p1 += 4; + p2 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + const float* p8 = (const float*)B + (j + jj + 8) * B_hstep + k; + const float* p9 = (const float*)B + (j + jj + 9) * B_hstep + k; + const float* pa = (const float*)B + (j + jj + 10) * B_hstep + k; + const float* pb = (const float*)B + (j + jj + 11) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r3 = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r4 = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r5 = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r6 = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r7 = vle32_v_f32m1(p7, vl); + vfloat32m1_t _r8 = vle32_v_f32m1(p8, vl); + vfloat32m1_t _r9 = vle32_v_f32m1(p9, vl); + vfloat32m1_t _ra = vle32_v_f32m1(pa, vl); + vfloat32m1_t _rb = vle32_v_f32m1(pb, vl); + + transpose4x4_ps(_r0, _r1, _r2, _r3, vl); + transpose4x4_ps(_r4, _r5, _r6, _r7, vl); + transpose4x4_ps(_r8, _r9, _ra, _rb, vl); + + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r8, vl); + vse32_v_f32m1(pp + 4 * 3, _r1, vl); + vse32_v_f32m1(pp + 4 * 4, _r5, vl); + vse32_v_f32m1(pp + 4 * 5, _r9, vl); + vse32_v_f32m1(pp + 4 * 6, _r2, vl); + vse32_v_f32m1(pp + 4 * 7, _r6, vl); + vse32_v_f32m1(pp + 4 * 8, _ra, vl); + vse32_v_f32m1(pp + 4 * 9, _r3, vl); + vse32_v_f32m1(pp + 4 * 10, _r7, vl); + vse32_v_f32m1(pp + 4 * 11, _rb, vl); + pp += 48; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + p8 += 4; + p9 += 4; + pa += 4; + pb += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp += 12; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + 
const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r3 = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r4 = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r5 = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r6 = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r7 = vle32_v_f32m1(p7, vl); + + transpose4x4_ps(_r0, _r1, _r2, _r3, vl); + transpose4x4_ps(_r4, _r5, _r6, _r7, vl); + + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t v2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t v3 = vle32_v_f32m1(p3, vl); + store_float_v4(v0, v1, v2, v3, pp, vl); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + // if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + } + for (; jj < max_jj; jj += 1) + { + // if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = 
p0[0]; + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +{ + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + float* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vfloat32m1_t _r8; + vfloat32m1_t _r9; + vfloat32m1_t _ra; + vfloat32m1_t _rb; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vlseg4e32_v_f32m1(&_r8, &_r9, &_ra, &_rb, p0 + 32, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r8, vl); + vse32_v_f32m1(pp + 4 * 3, _r1, vl); + vse32_v_f32m1(pp + 4 * 4, _r5, vl); + vse32_v_f32m1(pp + 4 * 5, _r9, vl); + vse32_v_f32m1(pp + 4 * 6, _r2, vl); + vse32_v_f32m1(pp + 4 * 7, _r6, vl); + vse32_v_f32m1(pp + 4 * 8, _ra, vl); + vse32_v_f32m1(pp + 4 * 9, _r3, vl); + vse32_v_f32m1(pp + 4 * 10, _r7, vl); + vse32_v_f32m1(pp + 4 * 11, _rb, vl); + pp += 48; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + vse32_v_f32m1(pp + 8, vle32_v_f32m1(p0 + 8, vl), vl); + pp += 12; + p0 += B_hstep; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + pp += 8; + p0 += B_hstep; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r1, vl); + vse32_v_f32m1(pp + 4 * 2, _r2, vl); + vse32_v_f32m1(pp + 4 * 3, _r3, vl); + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < 
max_jj; jj += 2) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj, size_t vl) +{ + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const float* pp = topT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t v1 = vle32_v_f32m1(pp + 8, vl); + vfloat32m1_t v2 = vle32_v_f32m1(pp + 16, vl); + vfloat32m1_t v3 = vle32_v_f32m1(pp + 24, vl); + store_float_v4(v0, v1, v2, v3, p0, vl); + v0 = vle32_v_f32m1(pp + 4, vl); + v1 = vle32_v_f32m1(pp + 12, vl); + v2 = vle32_v_f32m1(pp + 20, vl); + v3 = vle32_v_f32m1(pp + 28, vl); + store_float_v4(v0, v1, v2, v3, p0 + 16, vl); + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(pp + 4, vl); + vse32_v_f32m1(p0, _r0, vl); + vse32_v_f32m1(p0 + 4, _r1, vl); + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t v1 = vle32_v_f32m1(pp + 4, vl); + vfloat32m1_t v2 = vle32_v_f32m1(pp + 8, vl); + vfloat32m1_t v3 = vle32_v_f32m1(pp + 12, vl); + store_float_v4(v0, v1, v2, v3, p0, vl); + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vse32_v_f32m1(p0, _r0, vl); + pp += 4; + p0 += out_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + p0[0] = pp[0]; + p0[1] = pp[2]; + p0[2] = pp[4]; + p0[3] = pp[6]; + p0[4] = pp[1]; + p0[5] = pp[3]; + p0[6] = pp[5]; + p0[7] = pp[7]; + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + 
float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + p0[1] = pp[1]; + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vse32_v_f32m1(p0, _r0, vl); + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end, size_t vl) +{ + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const float* pAT = AT_tile; + const float* pBT = BT_tile; + const float* pC = CT_tile; + + float* outptr = topT_tile; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + vfloat32m1_t _sum40; + vfloat32m1_t _sum41; + vfloat32m1_t _sum50; + vfloat32m1_t _sum51; + vfloat32m1_t _sum60; + vfloat32m1_t _sum61; + vfloat32m1_t _sum70; + vfloat32m1_t _sum71; + vfloat32m1_t _sum80; + vfloat32m1_t _sum81; + vfloat32m1_t _sum90; + vfloat32m1_t _sum91; + vfloat32m1_t _suma0; + vfloat32m1_t _suma1; + vfloat32m1_t _sumb0; + vfloat32m1_t _sumb1; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + _sum40 = vfmv_v_f_f32m1(0.f, vl); + _sum41 = vfmv_v_f_f32m1(0.f, vl); + _sum50 = vfmv_v_f_f32m1(0.f, vl); + _sum51 = vfmv_v_f_f32m1(0.f, vl); + _sum60 = vfmv_v_f_f32m1(0.f, vl); + _sum61 = vfmv_v_f_f32m1(0.f, vl); + _sum70 = vfmv_v_f_f32m1(0.f, vl); + _sum71 = vfmv_v_f_f32m1(0.f, vl); + _sum80 = vfmv_v_f_f32m1(0.f, vl); + _sum81 = vfmv_v_f_f32m1(0.f, vl); + _sum90 = vfmv_v_f_f32m1(0.f, vl); + _sum91 = vfmv_v_f_f32m1(0.f, vl); + _suma0 = vfmv_v_f_f32m1(0.f, vl); + _suma1 = vfmv_v_f_f32m1(0.f, vl); + _sumb0 = vfmv_v_f_f32m1(0.f, vl); + _sumb1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + _sum40 = vfmv_v_f_f32m1(pC[0], vl); + _sum41 = vfmv_v_f_f32m1(pC[0], vl); + _sum50 = vfmv_v_f_f32m1(pC[0], vl); + _sum51 = 
vfmv_v_f_f32m1(pC[0], vl); + _sum60 = vfmv_v_f_f32m1(pC[0], vl); + _sum61 = vfmv_v_f_f32m1(pC[0], vl); + _sum70 = vfmv_v_f_f32m1(pC[0], vl); + _sum71 = vfmv_v_f_f32m1(pC[0], vl); + _sum80 = vfmv_v_f_f32m1(pC[0], vl); + _sum81 = vfmv_v_f_f32m1(pC[0], vl); + _sum90 = vfmv_v_f_f32m1(pC[0], vl); + _sum91 = vfmv_v_f_f32m1(pC[0], vl); + _suma0 = vfmv_v_f_f32m1(pC[0], vl); + _suma1 = vfmv_v_f_f32m1(pC[0], vl); + _sumb0 = vfmv_v_f_f32m1(pC[0], vl); + _sumb1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + _sum40 = _sum00; + _sum41 = _sum01; + _sum50 = _sum00; + _sum51 = _sum01; + _sum60 = _sum00; + _sum61 = _sum01; + _sum70 = _sum00; + _sum71 = _sum01; + _sum80 = _sum00; + _sum81 = _sum01; + _sum90 = _sum00; + _sum91 = _sum01; + _suma0 = _sum00; + _suma1 = _sum01; + _sumb0 = _sum00; + _sumb1 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + _sum40 = vle32_v_f32m1(pC + 4 * 8, vl); + _sum41 = vle32_v_f32m1(pC + 4 * 9, vl); + _sum50 = vle32_v_f32m1(pC + 4 * 10, vl); + _sum51 = vle32_v_f32m1(pC + 4 * 11, vl); + _sum60 = vle32_v_f32m1(pC + 4 * 12, vl); + _sum61 = vle32_v_f32m1(pC + 4 * 13, vl); + _sum70 = vle32_v_f32m1(pC + 4 * 14, vl); + _sum71 = vle32_v_f32m1(pC + 4 * 15, vl); + _sum80 = vle32_v_f32m1(pC + 4 * 16, vl); + _sum81 = vle32_v_f32m1(pC + 4 * 17, vl); + _sum90 = vle32_v_f32m1(pC + 4 * 18, vl); + _sum91 = vle32_v_f32m1(pC + 4 * 19, vl); + _suma0 = vle32_v_f32m1(pC + 4 * 20, vl); + _suma1 = vle32_v_f32m1(pC + 4 * 21, vl); + _sumb0 = vle32_v_f32m1(pC + 4 * 22, vl); + _sumb1 = vle32_v_f32m1(pC + 4 * 23, vl); + pC += 96; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum40 = vfmv_v_f_f32m1(pC[4], vl); + _sum50 = vfmv_v_f_f32m1(pC[5], vl); + _sum60 = vfmv_v_f_f32m1(pC[6], vl); + _sum70 = vfmv_v_f_f32m1(pC[7], vl); + _sum80 = vfmv_v_f_f32m1(pC[8], vl); + _sum90 = vfmv_v_f_f32m1(pC[9], vl); + _suma0 = vfmv_v_f_f32m1(pC[10], vl); + _sumb0 = vfmv_v_f_f32m1(pC[11], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + _sum41 = _sum40; + _sum51 = _sum50; + _sum61 = _sum60; + _sum71 = _sum70; + _sum81 = _sum80; + _sum91 = _sum90; + _suma1 = _suma0; + _sumb1 = _sumb0; + pC += 12; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum40 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum41 = vle32_v_f32m1(outptr + 4 * 9, vl); + _sum50 = vle32_v_f32m1(outptr + 4 * 10, vl); + _sum51 = vle32_v_f32m1(outptr + 4 * 11, vl); + _sum60 = vle32_v_f32m1(outptr + 4 * 12, vl); + _sum61 = vle32_v_f32m1(outptr + 4 * 13, vl); + _sum70 = vle32_v_f32m1(outptr + 4 * 14, vl); + _sum71 = 
vle32_v_f32m1(outptr + 4 * 15, vl); + _sum80 = vle32_v_f32m1(outptr + 4 * 16, vl); + _sum81 = vle32_v_f32m1(outptr + 4 * 17, vl); + _sum90 = vle32_v_f32m1(outptr + 4 * 18, vl); + _sum91 = vle32_v_f32m1(outptr + 4 * 19, vl); + _suma0 = vle32_v_f32m1(outptr + 4 * 20, vl); + _suma1 = vle32_v_f32m1(outptr + 4 * 21, vl); + _sumb0 = vle32_v_f32m1(outptr + 4 * 22, vl); + _sumb1 = vle32_v_f32m1(outptr + 4 * 23, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + 
_sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + } + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + 
_sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); + vse32_v_f32m1(outptr0 + 4 * 8, _sum80, vl); + vse32_v_f32m1(outptr0 + 4 * 9, _sum90, vl); + vse32_v_f32m1(outptr0 + 4 * 10, _suma0, vl); + vse32_v_f32m1(outptr0 + 4 * 11, _sumb0, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 8, _sum81, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 9, _sum91, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 10, _suma1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 11, _sumb1, vl); + + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose8x12_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, _sum80, _sum81, _sum90, _sum91, _suma0, _suma1, _sumb0, _sumb1, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + 8, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum40, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum50, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum60, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 8, _sum70, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum71, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 4, _sum80, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 8, _sum81, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum90, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum91, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 8, _suma0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _suma1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sumb0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 8, _sumb1, vl); + + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + 
vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); + vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); + vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); + vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); + vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); + vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); + vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); + vse32_v_f32m1(outptr + 4 * 16, _sum80, vl); + vse32_v_f32m1(outptr + 4 * 17, _sum81, vl); + vse32_v_f32m1(outptr + 4 * 18, _sum90, vl); + vse32_v_f32m1(outptr + 4 * 19, _sum91, vl); + vse32_v_f32m1(outptr + 4 * 20, _suma0, vl); + vse32_v_f32m1(outptr + 4 * 21, _suma1, vl); + vse32_v_f32m1(outptr + 4 * 22, _sumb0, vl); + vse32_v_f32m1(outptr + 4 * 23, _sumb1, vl); + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + vfloat32m1_t _sum40; + vfloat32m1_t _sum41; + vfloat32m1_t _sum50; + vfloat32m1_t _sum51; + vfloat32m1_t _sum60; + vfloat32m1_t _sum61; + vfloat32m1_t _sum70; + vfloat32m1_t _sum71; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + _sum40 = vfmv_v_f_f32m1(0.f, vl); + _sum41 = vfmv_v_f_f32m1(0.f, vl); + _sum50 = vfmv_v_f_f32m1(0.f, vl); + _sum51 = vfmv_v_f_f32m1(0.f, vl); + _sum60 = vfmv_v_f_f32m1(0.f, vl); + _sum61 = vfmv_v_f_f32m1(0.f, vl); + _sum70 = vfmv_v_f_f32m1(0.f, vl); + _sum71 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + _sum40 = vfmv_v_f_f32m1(pC[0], vl); + _sum41 = vfmv_v_f_f32m1(pC[0], vl); + _sum50 = vfmv_v_f_f32m1(pC[0], vl); + _sum51 = vfmv_v_f_f32m1(pC[0], vl); + _sum60 = vfmv_v_f_f32m1(pC[0], vl); + _sum61 = vfmv_v_f_f32m1(pC[0], vl); + _sum70 = vfmv_v_f_f32m1(pC[0], vl); + _sum71 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + _sum40 = _sum00; + _sum41 = _sum01; + _sum50 = _sum00; + _sum51 = _sum01; + _sum60 = _sum00; + _sum61 = _sum01; + _sum70 = _sum00; + _sum71 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + _sum40 = vle32_v_f32m1(pC + 4 * 8, vl); + _sum41 = vle32_v_f32m1(pC + 4 * 9, vl); + _sum50 = vle32_v_f32m1(pC + 4 * 10, vl); + _sum51 = vle32_v_f32m1(pC + 4 * 11, vl); + _sum60 = vle32_v_f32m1(pC + 4 * 12, vl); + _sum61 = vle32_v_f32m1(pC + 4 * 13, vl); + _sum70 = 
vle32_v_f32m1(pC + 4 * 14, vl); + _sum71 = vle32_v_f32m1(pC + 4 * 15, vl); + pC += 64; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum40 = vfmv_v_f_f32m1(pC[4], vl); + _sum50 = vfmv_v_f_f32m1(pC[5], vl); + _sum60 = vfmv_v_f_f32m1(pC[6], vl); + _sum70 = vfmv_v_f_f32m1(pC[7], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + _sum41 = _sum40; + _sum51 = _sum50; + _sum61 = _sum60; + _sum71 = _sum70; + pC += 8; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum40 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum41 = vle32_v_f32m1(outptr + 4 * 9, vl); + _sum50 = vle32_v_f32m1(outptr + 4 * 10, vl); + _sum51 = vle32_v_f32m1(outptr + 4 * 11, vl); + _sum60 = vle32_v_f32m1(outptr + 4 * 12, vl); + _sum61 = vle32_v_f32m1(outptr + 4 * 13, vl); + _sum70 = vle32_v_f32m1(outptr + 4 * 14, vl); + _sum71 = vle32_v_f32m1(outptr + 4 * 15, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + + pA += 8; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + 
vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 4, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _sum70, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sum71, vl); + + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); + vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); + vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); + vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); + vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); + vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); + vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); + } + + outptr += 64; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + pC += 4; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = 
vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + + pA += 8; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + out_hstep * 1, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _sum31, vl); + + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum01 = _sum00; + _sum11 = _sum10; + pC += 2; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + } + + const float* pA = pAT; + int kk = 
0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + + pA += 8; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[8]; + float sum1[8]; + vse32_v_f32m1(sum0, _sum00, vl); + vse32_v_f32m1(sum0 + 4, _sum01, vl); + vse32_v_f32m1(sum1, _sum10, vl); + vse32_v_f32m1(sum1 + 4, _sum11, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0 += 2; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = _sum00; + pC += 1; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + vfloat32m1_t _pB = vfmv_v_f_f32m1(pB[0], vl); + + _sum00 = vfmadd_vv_f32m1(_pA0, _pB, _sum00, vl); + _sum01 = vfmadd_vv_f32m1(_pA1, _pB, _sum01, vl); + + pA += 8; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[8]; + vse32_v_f32m1(sum0, _sum00, vl); + vse32_v_f32m1(sum0 + 4, _sum01, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0++; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + float* outptr0 = 
(float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; + vfloat32m1_t _sum8; + vfloat32m1_t _sum9; + vfloat32m1_t _suma; + vfloat32m1_t _sumb; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + _sum4 = vfmv_v_f_f32m1(0.f, vl); + _sum5 = vfmv_v_f_f32m1(0.f, vl); + _sum6 = vfmv_v_f_f32m1(0.f, vl); + _sum7 = vfmv_v_f_f32m1(0.f, vl); + _sum8 = vfmv_v_f_f32m1(0.f, vl); + _sum9 = vfmv_v_f_f32m1(0.f, vl); + _suma = vfmv_v_f_f32m1(0.f, vl); + _sumb = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + _sum4 = vfmv_v_f_f32m1(pC[0], vl); + _sum5 = vfmv_v_f_f32m1(pC[0], vl); + _sum6 = vfmv_v_f_f32m1(pC[0], vl); + _sum7 = vfmv_v_f_f32m1(pC[0], vl); + _sum8 = vfmv_v_f_f32m1(pC[0], vl); + _sum9 = vfmv_v_f_f32m1(pC[0], vl); + _suma = vfmv_v_f_f32m1(pC[0], vl); + _sumb = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + _sum4 = vle32_v_f32m1(pC + 16, vl); + _sum5 = vle32_v_f32m1(pC + 20, vl); + _sum6 = vle32_v_f32m1(pC + 24, vl); + _sum7 = vle32_v_f32m1(pC + 28, vl); + _sum8 = vle32_v_f32m1(pC + 32, vl); + _sum9 = vle32_v_f32m1(pC + 36, vl); + _suma = vle32_v_f32m1(pC + 40, vl); + _sumb = vle32_v_f32m1(pC + 44, vl); + pC += 48; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + _sum4 = vfmv_v_f_f32m1(pC[4], vl); + _sum5 = vfmv_v_f_f32m1(pC[5], vl); + _sum6 = vfmv_v_f_f32m1(pC[6], vl); + _sum7 = vfmv_v_f_f32m1(pC[7], vl); + _sum8 = vfmv_v_f_f32m1(pC[8], vl); + _sum9 = vfmv_v_f_f32m1(pC[9], vl); + _suma = vfmv_v_f_f32m1(pC[10], vl); + _sumb = vfmv_v_f_f32m1(pC[11], vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum8 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum9 = vle32_v_f32m1(outptr + 4 * 9, vl); + _suma = vle32_v_f32m1(outptr + 4 * 10, vl); + _sumb = vle32_v_f32m1(outptr + 4 * 11, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = 
vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + _sum8 = vfmadd_vf_f32m1(_pA, pB[8], _sum8, vl); + _sum9 = vfmadd_vf_f32m1(_pA, pB[9], _sum9, vl); + _suma = vfmadd_vf_f32m1(_pA, pB[10], _suma, vl); + _sumb = vfmadd_vf_f32m1(_pA, pB[11], _sumb, vl); + + pA += 4; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); + vse32_v_f32m1(outptr0 + 4 * 8, _sum8, vl); + vse32_v_f32m1(outptr0 + 4 * 9, _sum9, vl); + vse32_v_f32m1(outptr0 + 4 * 10, _suma, vl); + vse32_v_f32m1(outptr0 + 4 * 11, _sumb, vl); + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose4x12_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 8, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum3, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum4, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum5, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum6, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum7, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum8, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum9, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _suma, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sumb, vl); + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum8, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum9, vl); + vse32_v_f32m1(outptr + 4 * 10, _suma, vl); + vse32_v_f32m1(outptr + 4 * 11, _sumb, vl); + } + + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + _sum4 = vfmv_v_f_f32m1(0.f, vl); + _sum5 = vfmv_v_f_f32m1(0.f, vl); + _sum6 = vfmv_v_f_f32m1(0.f, vl); + _sum7 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + _sum4 = vfmv_v_f_f32m1(pC[0], vl); + _sum5 = vfmv_v_f_f32m1(pC[0], vl); + _sum6 = vfmv_v_f_f32m1(pC[0], vl); + _sum7 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; 
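+                        // broadcast_type_C 1/2: C carries one bias value per output row,
+                        // so all eight jj columns start from the same 4-row bias vector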
+ _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + _sum4 = vle32_v_f32m1(pC + 16, vl); + _sum5 = vle32_v_f32m1(pC + 20, vl); + _sum6 = vle32_v_f32m1(pC + 24, vl); + _sum7 = vle32_v_f32m1(pC + 28, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + _sum4 = vfmv_v_f_f32m1(pC[4], vl); + _sum5 = vfmv_v_f_f32m1(pC[5], vl); + _sum6 = vfmv_v_f_f32m1(pC[6], vl); + _sum7 = vfmv_v_f_f32m1(pC[7], vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m1(outptr + 4 * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + + pA += 4; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose4x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum3, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum4, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum5, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum6, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum7, vl); + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); + } + + outptr += 32; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = 
_sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + pC += 4; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + pA += 4; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ps(_sum0, _sum1, _sum2, _sum3, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 1, _sum1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum3, vl); + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + pC += 2; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + + pA += 4; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + vse32_v_f32m1(sum0, _sum0, vl); + vse32_v_f32m1(sum1, _sum1, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0 += 2; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m1_t _sum0; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C 
== 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + pC += 4; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + pC += 1; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pB = vfmv_v_f_f32m1(pB[0], vl); + + _sum0 = vfmadd_vv_f32m1(_pA, _pB, _sum0, vl); + + pA += 4; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[4]; + vse32_v_f32m1(sum0, _sum0, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0++; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum02; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum12; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum02 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum12 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum02 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum12 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum02 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum11 = vfmv_v_f_f32m1(pC[1], vl); + _sum12 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, pC, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, pC + 8, vl); + vlseg2e32_v_f32m1(&_sum02, &_sum12, pC + 16, vl); + pC += 24; + } + if (broadcast_type_C == 4) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum02 = vle32_v_f32m1(pC + 8, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum12 = _sum02; + pC += 12; + } + } + } + else + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, outptr, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, outptr + 8, vl); + vlseg2e32_v_f32m1(&_sum02, &_sum12, outptr + 16, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + vfloat32m1_t _pB2 = vle32_v_f32m1(pB + 8, vl); + + _sum00 = vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); + _sum02 = vfmadd_vf_f32m1(_pB2, pA[0], _sum02, vl); + _sum10 = vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); + _sum12 = vfmadd_vf_f32m1(_pB2, pA[1], 
_sum12, vl); + + pA += 2; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + 8, _sum02, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum12, vl); + outptr0 += 12; + } + } + else + { + store_float_v2(_sum00, _sum10, outptr, vl); + store_float_v2(_sum01, _sum11, outptr + 8, vl); + store_float_v2(_sum02, _sum12, outptr + 16, vl); + } + + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum11 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, pC, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, pC + 8, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + pC += 8; + } + } + } + else + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, outptr, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, outptr + 8, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); + pA += 2; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + outptr0 += 8; + } + } + else + { + store_float_v2(_sum00, _sum10, outptr, vl); + store_float_v2(_sum01, _sum11, outptr + 8, vl); + } + + outptr += 16; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum0, &_sum1, pC, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + pC += 4; + } + } + } + else + { + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, outptr, vl); + _sum0 = _tmp0; + _sum1 = _tmp1; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB = vle32_v_f32m1(pB, vl); + + _sum0 = vfmadd_vf_f32m1(_pB, pA[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pB, pA[1], _sum1, vl); + + pA += 2; + pB += 4; 
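+                // each k step: one 4-wide strip of B is scaled by the two scalar rows
+                // of A (pA[0] / pA[1]) and accumulated into the two row sums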
+ } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum1, vl); + outptr0 += 4; + } + } + else + { + store_float_v2(_sum0, _sum1, outptr, vl); + } + + outptr += 8; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + float sum00; + float sum01; + float sum10; + float sum11; + + if (k == 0) + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[0]; + sum11 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[2]; + sum11 = pC[3]; + pC += 4; + } + if (broadcast_type_C == 4) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[1]; + sum11 = pC[1]; + pC += 2; + } + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + float sum0; + float sum1; + + if (k == 0) + { + sum0 = 0.f; + sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[0]; + pC += 1; + } + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + _sum2 = vle32_v_f32m1(outptr + 8, 
vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + vfloat32m1_t _pB2 = vle32_v_f32m1(pB + 8, vl); + + vfloat32m1_t _pA0 = vfmv_v_f_f32m1(pA[0], vl); + + _sum0 = vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); + _sum1 = vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); + _sum2 = vfmadd_vv_f32m1(_pA0, _pB2, _sum2, vl); + + pA += 1; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 8, _sum2, vl); + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 8, _sum2, vl); + } + + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + + vfloat32m1_t _pA0 = vfmv_v_f_f32m1(pA[0], vl); + _sum0 = vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); + _sum1 = vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); + + pA += 1; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum; + + if (k == 0) + { + _sum = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = vle32_v_f32m1(pC, vl); + pC += 4; + } + } + } + else + { + _sum = vle32_v_f32m1(outptr, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pA = vfmv_v_f_f32m1(pA[0], vl); + + _sum = vfmadd_vv_f32m1(_pA, _pB, _sum, vl); + + pA += 1; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum, vl); + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum, vl); + } + + outptr += 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + float sum0; + float sum1; + + if (k == 0) + { + sum0 = 0.f; + sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + + pA += 1; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + 
outptr0 += 2; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + float sum; + + if (k == 0) + { + sum = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum = pC[0]; + pC += 1; + } + } + } + else + { + sum = outptr[0]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum; + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / 3 / sizeof(float)); + + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + TILE_K = std::max(8, tile_size / 8 * 8); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + 7) / 8 * 8; + } + + if (constant_TILE_N > 0) + { + TILE_N = (constant_TILE_N + 3) / 4 * 4; + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + 7) / 8 * 8; + } +} + +static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B +#if TIME_TEST + gettimeofday(&start_time, NULL); +#endif + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int N = transB ? (B.dims == 3 ? 
B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? 
A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +int Gemm_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + std::vector bottom_blobs(1, bottom_blob); + std::vector top_blobs(1, top_blob); + int ret = forward(bottom_blobs, top_blobs, opt); + top_blob = top_blobs[0]; + return ret; +} + +int Gemm_riscv::create_pipeline(const Option& opt) +{ + if (constantA) + { + const int M = constantM; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, (Allocator*)0); + if (AT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + } + } + } + + if (opt.lightmode) + { + A_data.release(); + } + } + + if (constantB) + { + const int N = constantN; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_N = (N + TILE_N - 1) / TILE_N; + + BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_N; ppj++) + { + const int j = ppj * TILE_N; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + } + } + } + + if (opt.lightmode) + { + B_data.release(); + } + } + + if (constantC && constant_broadcast_type_C != -1) + { + CT_data = C_data; + +#if __riscv_vector + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) + { + int C_elempack = constantM % 4 == 0 ? 
4 : 1; + convert_packing(C_data, CT_data, C_elempack, opt); + } +#endif // __riscv_vector + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat C2; + C2.create_like(CT_data); + + const int size = CT_data.total() * CT_data.elempack; + for (int i = 0; i < size; i++) + { + C2[i] = CT_data[i] * beta; + } + + CT_data = C2; + } + + if (opt.lightmode) + { + C_data.release(); + } + } + + if (constantA || constantB || constantC) + { + nT = opt.num_threads; + } + + return 0; +} + +int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat CT_data; + CT_data.create_like(C, opt.workspace_allocator); + + const int size = C.total() * C.elempack; + for (int i = 0; i < size; i++) + { + CT_data[i] = C[i] * beta; + } + + C = CT_data; + } + } + } + + int out_elempack = 1; +#if __riscv_vector + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % 4 == 0 ? 4 : 1; + } +#endif // __riscv_vector + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 4u * out_elempack; + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_riscv(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_riscv(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_riscv(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_riscv(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + if (ret != 0) + return ret; + + // multiply top_blob with alpha + if (alpha != 1.f) + { + const int size = top_blob.total() * out_elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < size; i++) + { + top_blob[i] *= alpha; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/riscv/gemm_riscv.h b/src/layer/riscv/gemm_riscv.h new file mode 100644 index 00000000000..b92add63891 --- /dev/null +++ b/src/layer/riscv/gemm_riscv.h @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
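+// Gemm_riscv pre-packs constant operands in create_pipeline(): AT_data/BT_data hold
+// the tiled A/B blocks, CT_data holds the constant C pre-scaled by beta, nT records
+// the thread count used for that packing (forward() reuses it so the pre-packed tiles
+// keep the same layout), and vl is the vector length passed to the RVV kernels.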
+ +#ifndef LAYER_GEMM_RISCV_H +#define LAYER_GEMM_RISCV_H + +#include "gemm.h" + +namespace ncnn { + +class Gemm_riscv : virtual public Gemm +{ +public: + Gemm_riscv(); + + virtual int create_pipeline(const Option& opt); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + + // public: + int nT; + size_t vl; + Mat AT_data; + Mat BT_data; + Mat CT_data; +}; + +} // namespace ncnn + +#endif // LAYER_GEMM_RISCV_H diff --git a/src/layer/riscv/instancenorm_riscv.cpp b/src/layer/riscv/instancenorm_riscv.cpp index 95616866b8a..20cf5d94c7d 100644 --- a/src/layer/riscv/instancenorm_riscv.cpp +++ b/src/layer/riscv/instancenorm_riscv.cpp @@ -14,8 +14,6 @@ #include "instancenorm_riscv.h" -#include - #if __riscv_vector #include #endif // __riscv_vector diff --git a/src/layer/riscv/interp_riscv.cpp b/src/layer/riscv/interp_riscv.cpp index ea8344985ed..ac72cf9b63c 100644 --- a/src/layer/riscv/interp_riscv.cpp +++ b/src/layer/riscv/interp_riscv.cpp @@ -14,8 +14,6 @@ #include "interp_riscv.h" -#include - #if __riscv_vector #include #include "riscv_usability.h" diff --git a/src/layer/riscv/mish_riscv.cpp b/src/layer/riscv/mish_riscv.cpp index 4ddb1470006..57b17d3a732 100644 --- a/src/layer/riscv/mish_riscv.cpp +++ b/src/layer/riscv/mish_riscv.cpp @@ -20,8 +20,6 @@ #include "rvv_mathfun_fp16s.h" #endif // __riscv_vector -#include - namespace ncnn { Mish_riscv::Mish_riscv() diff --git a/src/layer/riscv/riscv_usability.h b/src/layer/riscv/riscv_usability.h index 596bf4435c6..938d3ce3998 100644 --- a/src/layer/riscv/riscv_usability.h +++ b/src/layer/riscv/riscv_usability.h @@ -86,6 +86,282 @@ static inline vfloat32m8_t vle32_v_f32m8_f32m1(const float* ptr) return vloxei32_v_f32m8(ptr, bindex, vl); } +static inline void transpose8x8_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, + vfloat32m1_t& _r4l, vfloat32m1_t& _r4h, + vfloat32m1_t& _r5l, vfloat32m1_t& _r5h, + vfloat32m1_t& _r6l, vfloat32m1_t& _r6h, + vfloat32m1_t& _r7l, vfloat32m1_t& _r7h, size_t vl) +{ + float tmp[8][8]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 8, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 8, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 8, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 8, _r3h, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4l, vl); + vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 8, _r4h, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5l, vl); + vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 8, _r5h, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6l, vl); + vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 8, _r6h, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7l, vl); + vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 8, _r7h, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); 
+ _r4l = vle32_v_f32m1(ptr + 8 * 4, vl); + _r4h = vle32_v_f32m1(ptr + 9 * 4, vl); + _r5l = vle32_v_f32m1(ptr + 10 * 4, vl); + _r5h = vle32_v_f32m1(ptr + 11 * 4, vl); + _r6l = vle32_v_f32m1(ptr + 12 * 4, vl); + _r6h = vle32_v_f32m1(ptr + 13 * 4, vl); + _r7l = vle32_v_f32m1(ptr + 14 * 4, vl); + _r7h = vle32_v_f32m1(ptr + 15 * 4, vl); +} + +static inline void transpose4x4_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, size_t vl) +{ + float tmp[4][4]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 4, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 4, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 4, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 4, _r3, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); +} + +static inline void transpose8x12_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, + vfloat32m1_t& _r4l, vfloat32m1_t& _r4h, + vfloat32m1_t& _r5l, vfloat32m1_t& _r5h, + vfloat32m1_t& _r6l, vfloat32m1_t& _r6h, + vfloat32m1_t& _r7l, vfloat32m1_t& _r7h, + vfloat32m1_t& _r8l, vfloat32m1_t& _r8h, + vfloat32m1_t& _r9l, vfloat32m1_t& _r9h, + vfloat32m1_t& _ral, vfloat32m1_t& _rah, + vfloat32m1_t& _rbl, vfloat32m1_t& _rbh, size_t vl) +{ + float tmp[8][12]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 12, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 12, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 12, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 12, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 12, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 12, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 12, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 12, _r3h, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 12, _r4l, vl); + vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 12, _r4h, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 12, _r5l, vl); + vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 12, _r5h, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 12, _r6l, vl); + vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 12, _r6h, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 12, _r7l, vl); + vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 12, _r7h, vl); + vsse32_v_f32m1(&tmp[0][8], sizeof(float) * 12, _r8l, vl); + vsse32_v_f32m1(&tmp[4][8], sizeof(float) * 12, _r8h, vl); + vsse32_v_f32m1(&tmp[0][9], sizeof(float) * 12, _r9l, vl); + vsse32_v_f32m1(&tmp[4][9], sizeof(float) * 12, _r9h, vl); + vsse32_v_f32m1(&tmp[0][10], sizeof(float) * 12, _ral, vl); + vsse32_v_f32m1(&tmp[4][10], sizeof(float) * 12, _rah, vl); + vsse32_v_f32m1(&tmp[0][11], sizeof(float) * 12, _rbl, vl); + vsse32_v_f32m1(&tmp[4][11], sizeof(float) * 12, _rbh, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); + _r4l = vle32_v_f32m1(ptr + 8 * 4, vl); + _r4h = vle32_v_f32m1(ptr + 9 * 4, vl); + _r5l = vle32_v_f32m1(ptr + 10 * 4, vl); + _r5h = vle32_v_f32m1(ptr + 11 * 4, vl); + _r6l = vle32_v_f32m1(ptr + 12 * 4, vl); + _r6h = vle32_v_f32m1(ptr + 13 * 4, vl); + _r7l = 
vle32_v_f32m1(ptr + 14 * 4, vl);
+    _r7h = vle32_v_f32m1(ptr + 15 * 4, vl);
+    _r8l = vle32_v_f32m1(ptr + 16 * 4, vl);
+    _r8h = vle32_v_f32m1(ptr + 17 * 4, vl);
+    _r9l = vle32_v_f32m1(ptr + 18 * 4, vl);
+    _r9h = vle32_v_f32m1(ptr + 19 * 4, vl);
+    _ral = vle32_v_f32m1(ptr + 20 * 4, vl);
+    _rah = vle32_v_f32m1(ptr + 21 * 4, vl);
+    _rbl = vle32_v_f32m1(ptr + 22 * 4, vl);
+    _rbh = vle32_v_f32m1(ptr + 23 * 4, vl);
+}
+
+static inline void transpose12x8_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0m, vfloat32m1_t& _r0h,
+                                    vfloat32m1_t& _r1l, vfloat32m1_t& _r1m, vfloat32m1_t& _r1h,
+                                    vfloat32m1_t& _r2l, vfloat32m1_t& _r2m, vfloat32m1_t& _r2h,
+                                    vfloat32m1_t& _r3l, vfloat32m1_t& _r3m, vfloat32m1_t& _r3h,
+                                    vfloat32m1_t& _r4l, vfloat32m1_t& _r4m, vfloat32m1_t& _r4h,
+                                    vfloat32m1_t& _r5l, vfloat32m1_t& _r5m, vfloat32m1_t& _r5h,
+                                    vfloat32m1_t& _r6l, vfloat32m1_t& _r6m, vfloat32m1_t& _r6h,
+                                    vfloat32m1_t& _r7l, vfloat32m1_t& _r7m, vfloat32m1_t& _r7h, size_t vl)
+{
+    float tmp[12][8];
+    vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0l, vl);
+    vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 8, _r0m, vl);
+    vsse32_v_f32m1(&tmp[8][0], sizeof(float) * 8, _r0h, vl);
+    vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1l, vl);
+    vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 8, _r1m, vl);
+    vsse32_v_f32m1(&tmp[8][1], sizeof(float) * 8, _r1h, vl);
+    vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2l, vl);
+    vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 8, _r2m, vl);
+    vsse32_v_f32m1(&tmp[8][2], sizeof(float) * 8, _r2h, vl);
+    vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3l, vl);
+    vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 8, _r3m, vl);
+    vsse32_v_f32m1(&tmp[8][3], sizeof(float) * 8, _r3h, vl);
+    vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4l, vl);
+    vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 8, _r4m, vl);
+    vsse32_v_f32m1(&tmp[8][4], sizeof(float) * 8, _r4h, vl);
+    vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5l, vl);
+    vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 8, _r5m, vl);
+    vsse32_v_f32m1(&tmp[8][5], sizeof(float) * 8, _r5h, vl);
+    vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6l, vl);
+    vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 8, _r6m, vl);
+    vsse32_v_f32m1(&tmp[8][6], sizeof(float) * 8, _r6h, vl);
+    vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7l, vl);
+    vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 8, _r7m, vl);
+    vsse32_v_f32m1(&tmp[8][7], sizeof(float) * 8, _r7h, vl);
+    float* ptr = (float*)tmp;
+    _r0l = vle32_v_f32m1(ptr + 0 * 4, vl);
+    _r0m = vle32_v_f32m1(ptr + 1 * 4, vl);
+    _r0h = vle32_v_f32m1(ptr + 2 * 4, vl);
+    _r1l = vle32_v_f32m1(ptr + 3 * 4, vl);
+    _r1m = vle32_v_f32m1(ptr + 4 * 4, vl);
+    _r1h = vle32_v_f32m1(ptr + 5 * 4, vl);
+    _r2l = vle32_v_f32m1(ptr + 6 * 4, vl);
+    _r2m = vle32_v_f32m1(ptr + 7 * 4, vl);
+    _r2h = vle32_v_f32m1(ptr + 8 * 4, vl);
+    _r3l = vle32_v_f32m1(ptr + 9 * 4, vl);
+    _r3m = vle32_v_f32m1(ptr + 10 * 4, vl);
+    _r3h = vle32_v_f32m1(ptr + 11 * 4, vl);
+    _r4l = vle32_v_f32m1(ptr + 12 * 4, vl);
+    _r4m = vle32_v_f32m1(ptr + 13 * 4, vl);
+    _r4h = vle32_v_f32m1(ptr + 14 * 4, vl);
+    _r5l = vle32_v_f32m1(ptr + 15 * 4, vl);
+    _r5m = vle32_v_f32m1(ptr + 16 * 4, vl);
+    _r5h = vle32_v_f32m1(ptr + 17 * 4, vl);
+    _r6l = vle32_v_f32m1(ptr + 18 * 4, vl);
+    _r6m = vle32_v_f32m1(ptr + 19 * 4, vl);
+    _r6h = vle32_v_f32m1(ptr + 20 * 4, vl);
+    _r7l = vle32_v_f32m1(ptr + 21 * 4, vl);
+    _r7m = vle32_v_f32m1(ptr + 22 * 4, vl);
+    _r7h = vle32_v_f32m1(ptr + 23 * 4, vl);
+}
+
+static inline void transpose4x8_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, vfloat32m1_t&
_r4, vfloat32m1_t& _r5, vfloat32m1_t& _r6, vfloat32m1_t& _r7, size_t vl) +{ + float tmp[4][8]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); + _r4 = vle32_v_f32m1(ptr + 4 * 4, vl); + _r5 = vle32_v_f32m1(ptr + 5 * 4, vl); + _r6 = vle32_v_f32m1(ptr + 6 * 4, vl); + _r7 = vle32_v_f32m1(ptr + 7 * 4, vl); +} + +static inline void transpose4x12_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, vfloat32m1_t& _r4, vfloat32m1_t& _r5, vfloat32m1_t& _r6, vfloat32m1_t& _r7, vfloat32m1_t& _r8, vfloat32m1_t& _r9, vfloat32m1_t& _ra, vfloat32m1_t& _rb, size_t vl) +{ + float tmp[4][12]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 12, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 12, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 12, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 12, _r3, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 12, _r4, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 12, _r5, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 12, _r6, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 12, _r7, vl); + vsse32_v_f32m1(&tmp[0][8], sizeof(float) * 12, _r8, vl); + vsse32_v_f32m1(&tmp[0][9], sizeof(float) * 12, _r9, vl); + vsse32_v_f32m1(&tmp[0][10], sizeof(float) * 12, _ra, vl); + vsse32_v_f32m1(&tmp[0][11], sizeof(float) * 12, _rb, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); + _r4 = vle32_v_f32m1(ptr + 4 * 4, vl); + _r5 = vle32_v_f32m1(ptr + 5 * 4, vl); + _r6 = vle32_v_f32m1(ptr + 6 * 4, vl); + _r7 = vle32_v_f32m1(ptr + 7 * 4, vl); + _r8 = vle32_v_f32m1(ptr + 8 * 4, vl); + _r9 = vle32_v_f32m1(ptr + 9 * 4, vl); + _ra = vle32_v_f32m1(ptr + 10 * 4, vl); + _rb = vle32_v_f32m1(ptr + 11 * 4, vl); +} + +static inline void transpose8x4_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, size_t vl) +{ + float tmp[8][4]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 4, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 4, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 4, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 4, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 4, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 4, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 4, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 4, _r3h, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); +} + +static inline void store_float_v2(vfloat32m1_t& 
vector1, vfloat32m1_t& vector2, float* buf, size_t vl) +{ + vsse32_v_f32m1(buf + 0, sizeof(float) * 2, vector1, vl); + vsse32_v_f32m1(buf + 1, sizeof(float) * 2, vector2, vl); +} + +static inline void store_float_v4(vfloat32m1_t& vector1, vfloat32m1_t& vector2, vfloat32m1_t& vector3, vfloat32m1_t& vector4, float* buf, size_t vl) +{ + vsse32_v_f32m1(buf + 0, sizeof(float) * 4, vector1, vl); + vsse32_v_f32m1(buf + 1, sizeof(float) * 4, vector2, vl); + vsse32_v_f32m1(buf + 2, sizeof(float) * 4, vector3, vl); + vsse32_v_f32m1(buf + 3, sizeof(float) * 4, vector4, vl); +} + #if __riscv_zfh static inline vfloat16m8_t vle16_v_f16m8_f16m1(const __fp16* ptr) { diff --git a/src/layer/riscv/rvv_mathfun.h b/src/layer/riscv/rvv_mathfun.h index ebf980060a7..34f072788e5 100644 --- a/src/layer/riscv/rvv_mathfun.h +++ b/src/layer/riscv/rvv_mathfun.h @@ -365,6 +365,19 @@ _RVV_FLOAT32_POW_OP(2, 16) _RVV_FLOAT32_POW_OP(4, 8) _RVV_FLOAT32_POW_OP(8, 4) +#if C906 +#define _RVV_FLOAT32_SIGMOID_OP(LMUL, MLEN) \ + static inline vfloat32m##LMUL##_t sigmoid_ps(vfloat32m##LMUL##_t _v, size_t vl) \ + { \ + _v = vfneg_v_f32m##LMUL(_v, vl); \ + _v = exp_ps(_v, vl); \ + _v = vfadd_vf_f32m##LMUL(_v, 1.f, vl); \ + vfloat32m##LMUL##_t _reciprocal = vfrdiv_vf_f32m##LMUL(_v, 1.f, vl); \ + _reciprocal = vfmul_vv_f32m##LMUL(vfrsub_vf_f32m##LMUL(vfmul_vv_f32m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); \ + /* _reciprocal = vfmul_vv_f32m##LMUL(vfrsub_vf_f32m##LMUL(vfmul_vv_f32m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); */ \ + return _reciprocal; \ + } +#else // C906 #define _RVV_FLOAT32_SIGMOID_OP(LMUL, MLEN) \ static inline vfloat32m##LMUL##_t sigmoid_ps(vfloat32m##LMUL##_t _v, size_t vl) \ { \ @@ -376,6 +389,7 @@ _RVV_FLOAT32_POW_OP(8, 4) /* _reciprocal = vfmul_vv_f32m##LMUL(vfrsub_vf_f32m##LMUL(vfmul_vv_f32m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); */ \ return _reciprocal; \ } +#endif // C906 _RVV_FLOAT32_SIGMOID_OP(1, 32) _RVV_FLOAT32_SIGMOID_OP(2, 16) diff --git a/src/layer/riscv/rvv_mathfun_fp16s.h b/src/layer/riscv/rvv_mathfun_fp16s.h index 47671fe21f0..ee5ffe4a304 100644 --- a/src/layer/riscv/rvv_mathfun_fp16s.h +++ b/src/layer/riscv/rvv_mathfun_fp16s.h @@ -365,6 +365,19 @@ _RVV_FLOAT16_POW_OP(2, 8) _RVV_FLOAT16_POW_OP(4, 4) _RVV_FLOAT16_POW_OP(8, 2) +#if C906 +#define _RVV_FLOAT16_SIGMOID_OP(LMUL, MLEN) \ + static inline vfloat16m##LMUL##_t sigmoid_ps(vfloat16m##LMUL##_t _v, size_t vl) \ + { \ + _v = vfneg_v_f16m##LMUL(_v, vl); \ + _v = exp_ps(_v, vl); \ + _v = vfadd_vf_f16m##LMUL(_v, 1.f, vl); \ + vfloat16m##LMUL##_t _reciprocal = vfrdiv_vf_f16m##LMUL(_v, 1.f, vl); \ + _reciprocal = vfmul_vv_f16m##LMUL(vfrsub_vf_f16m##LMUL(vfmul_vv_f16m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); \ + /* _reciprocal = vfmul_vv_f16m##LMUL(vfrsub_vf_f16m##LMUL(vfmul_vv_f16m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); */ \ + return _reciprocal; \ + } +#else // C906 #define _RVV_FLOAT16_SIGMOID_OP(LMUL, MLEN) \ static inline vfloat16m##LMUL##_t sigmoid_ps(vfloat16m##LMUL##_t _v, size_t vl) \ { \ @@ -376,6 +389,7 @@ _RVV_FLOAT16_POW_OP(8, 2) /* _reciprocal = vfmul_vv_f16m##LMUL(vfrsub_vf_f16m##LMUL(vfmul_vv_f16m##LMUL(_v, _reciprocal, vl), 2.f, vl), _reciprocal, vl); */ \ return _reciprocal; \ } +#endif // C906 _RVV_FLOAT16_SIGMOID_OP(1, 16) _RVV_FLOAT16_SIGMOID_OP(2, 8) diff --git a/src/layer/riscv/sigmoid_riscv.cpp b/src/layer/riscv/sigmoid_riscv.cpp index 6c10582c668..14770f95e78 100644 --- a/src/layer/riscv/sigmoid_riscv.cpp +++ b/src/layer/riscv/sigmoid_riscv.cpp @@ -20,8 
+20,6 @@ #include "rvv_mathfun_fp16s.h" #endif // __riscv_vector -#include - namespace ncnn { Sigmoid_riscv::Sigmoid_riscv() diff --git a/src/layer/riscv/swish_riscv.cpp b/src/layer/riscv/swish_riscv.cpp index 17493d7db69..7e2e2488c42 100644 --- a/src/layer/riscv/swish_riscv.cpp +++ b/src/layer/riscv/swish_riscv.cpp @@ -20,8 +20,6 @@ #include "rvv_mathfun_fp16s.h" #endif // __riscv_vector -#include - namespace ncnn { Swish_riscv::Swish_riscv() diff --git a/src/layer/riscv/tanh_riscv.cpp b/src/layer/riscv/tanh_riscv.cpp index d47de61dc59..0c147b15bd6 100644 --- a/src/layer/riscv/tanh_riscv.cpp +++ b/src/layer/riscv/tanh_riscv.cpp @@ -20,8 +20,6 @@ #include "rvv_mathfun_fp16s.h" #endif // __riscv_vector -#include - namespace ncnn { TanH_riscv::TanH_riscv() diff --git a/src/layer/riscv/unaryop_riscv.cpp b/src/layer/riscv/unaryop_riscv.cpp index 444312df1de..b6acf25e438 100644 --- a/src/layer/riscv/unaryop_riscv.cpp +++ b/src/layer/riscv/unaryop_riscv.cpp @@ -20,8 +20,6 @@ #include "rvv_mathfun_fp16s.h" #endif // __riscv_vector -#include - namespace ncnn { UnaryOp_riscv::UnaryOp_riscv() @@ -127,9 +125,13 @@ struct unary_op_rsqrt { vfloat32m8_t operator()(const vfloat32m8_t& x, const size_t& vl) const { +#if C906 + vfloat32m8_t _reciprocal = vfrdiv_vf_f32m8(vfsqrt_v_f32m8(x, vl), 1.f, vl); +#else vfloat32m8_t _reciprocal = vfrsqrt7_v_f32m8(x, vl); _reciprocal = vfmul_vv_f32m8(vfrsub_vf_f32m8(vfmul_vv_f32m8(vfmul_vf_f32m8(x, 0.5f, vl), vfmul_vv_f32m8(_reciprocal, _reciprocal, vl), vl), 1.5f, vl), _reciprocal, vl); // _reciprocal = vfmul_vv_f32m8(vfrsub_vf_f32m8(vfmul_vv_f32m8(vfmul_vf_f32m8(x, 0.5f, vl), vfmul_vv_f32m8(_reciprocal, _reciprocal, vl), vl), 1.5f, vl), _reciprocal, vl); +#endif return _reciprocal; } }; @@ -230,9 +232,13 @@ struct unary_op_reciprocal { vfloat32m8_t operator()(const vfloat32m8_t& x, const size_t& vl) const { +#if C906 + vfloat32m8_t _reciprocal = vfrdiv_vf_f32m8(x, 1.f, vl); +#else vfloat32m8_t _reciprocal = vfrec7_v_f32m8(x, vl); _reciprocal = vfmul_vv_f32m8(vfrsub_vf_f32m8(vfmul_vv_f32m8(x, _reciprocal, vl), 2.f, vl), _reciprocal, vl); // _reciprocal = vfmul_vv_f32m8(vfrsub_vf_f32m8(vfmul_vv_f32m8(x, _reciprocal, vl), 2.f, vl), _reciprocal, vl); +#endif return _reciprocal; } }; @@ -459,9 +465,13 @@ struct unary_op_rsqrt_fp16s { vfloat16m8_t operator()(const vfloat16m8_t& x, const size_t& vl) const { +#if C906 + vfloat16m8_t _reciprocal = vfrdiv_vf_f16m8(vfsqrt_v_f16m8(x, vl), 1.f, vl); +#else vfloat16m8_t _reciprocal = vfrsqrt7_v_f16m8(x, vl); _reciprocal = vfmul_vv_f16m8(vfrsub_vf_f16m8(vfmul_vv_f16m8(vfmul_vf_f16m8(x, 0.5f, vl), vfmul_vv_f16m8(_reciprocal, _reciprocal, vl), vl), 1.5f, vl), _reciprocal, vl); // _reciprocal = vfmul_vv_f16m8(vfrsub_vf_f16m8(vfmul_vv_f16m8(vfmul_vf_f16m8(x, 0.5f, vl), vfmul_vv_f16m8(_reciprocal, _reciprocal, vl), vl), 1.5f, vl), _reciprocal, vl); +#endif return _reciprocal; } }; @@ -562,9 +572,13 @@ struct unary_op_reciprocal_fp16s { vfloat16m8_t operator()(const vfloat16m8_t& x, const size_t& vl) const { +#if C906 + vfloat16m8_t _reciprocal = vfrdiv_vf_f16m8(x, 1.f, vl); +#else vfloat16m8_t _reciprocal = vfrec7_v_f16m8(x, vl); _reciprocal = vfmul_vv_f16m8(vfrsub_vf_f16m8(vfmul_vv_f16m8(x, _reciprocal, vl), 2.f, vl), _reciprocal, vl); // _reciprocal = vfmul_vv_f16m8(vfrsub_vf_f16m8(vfmul_vv_f16m8(x, _reciprocal, vl), 2.f, vl), _reciprocal, vl); +#endif return _reciprocal; } }; diff --git a/src/layer/rnn.cpp b/src/layer/rnn.cpp index d1856ce6fa9..6cc8ba5c9bd 100644 --- a/src/layer/rnn.cpp +++ b/src/layer/rnn.cpp @@ -14,8 +14,6 @@ 
#include "rnn.h" -#include - namespace ncnn { RNN::RNN() diff --git a/src/layer/roialign.cpp b/src/layer/roialign.cpp index 3d1c14538ce..a344f67f79d 100644 --- a/src/layer/roialign.cpp +++ b/src/layer/roialign.cpp @@ -15,7 +15,6 @@ #include "roialign.h" #include -#include namespace ncnn { diff --git a/src/layer/roipooling.cpp b/src/layer/roipooling.cpp index 96b43d3850f..9fd843737a3 100644 --- a/src/layer/roipooling.cpp +++ b/src/layer/roipooling.cpp @@ -14,8 +14,6 @@ #include "roipooling.h" -#include - namespace ncnn { ROIPooling::ROIPooling() diff --git a/src/layer/selu.cpp b/src/layer/selu.cpp index c3d1f962604..42a4ff2a813 100644 --- a/src/layer/selu.cpp +++ b/src/layer/selu.cpp @@ -14,8 +14,6 @@ #include "selu.h" -#include - namespace ncnn { SELU::SELU() @@ -36,8 +34,9 @@ int SELU::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { int w = bottom_top_blob.w; int h = bottom_top_blob.h; + int d = bottom_top_blob.d; int channels = bottom_top_blob.c; - int size = w * h; + int size = w * h * d; float alphaxlambda = alpha * lambda; #pragma omp parallel for num_threads(opt.num_threads) diff --git a/src/layer/sigmoid.cpp b/src/layer/sigmoid.cpp index 963c0f98f5a..4ed0dab5e81 100644 --- a/src/layer/sigmoid.cpp +++ b/src/layer/sigmoid.cpp @@ -14,8 +14,6 @@ #include "sigmoid.h" -#include - namespace ncnn { Sigmoid::Sigmoid() diff --git a/src/layer/softmax.cpp b/src/layer/softmax.cpp index a948f07f354..2768a82c20f 100644 --- a/src/layer/softmax.cpp +++ b/src/layer/softmax.cpp @@ -15,7 +15,6 @@ #include "softmax.h" #include -#include namespace ncnn { diff --git a/src/layer/softplus.cpp b/src/layer/softplus.cpp index 615496037c4..4910aad2949 100644 --- a/src/layer/softplus.cpp +++ b/src/layer/softplus.cpp @@ -14,8 +14,6 @@ #include "softplus.h" -#include - namespace ncnn { Softplus::Softplus() diff --git a/src/layer/spp.cpp b/src/layer/spp.cpp index a2678a32a8b..b7070955cb8 100644 --- a/src/layer/spp.cpp +++ b/src/layer/spp.cpp @@ -14,8 +14,6 @@ #include "spp.h" -#include - namespace ncnn { SPP::SPP() diff --git a/src/layer/statisticspooling.cpp b/src/layer/statisticspooling.cpp index 1947b61c875..9ed6d22f417 100644 --- a/src/layer/statisticspooling.cpp +++ b/src/layer/statisticspooling.cpp @@ -14,7 +14,6 @@ #include #include -#include namespace ncnn { diff --git a/src/layer/swish.cpp b/src/layer/swish.cpp index 3d8f3e3d65f..2816230c180 100644 --- a/src/layer/swish.cpp +++ b/src/layer/swish.cpp @@ -14,8 +14,6 @@ #include "swish.h" -#include - namespace ncnn { Swish::Swish() diff --git a/src/layer/tanh.cpp b/src/layer/tanh.cpp index a7d0249e1b9..c4b68352af6 100644 --- a/src/layer/tanh.cpp +++ b/src/layer/tanh.cpp @@ -14,8 +14,6 @@ #include "tanh.h" -#include - namespace ncnn { TanH::TanH() diff --git a/src/layer/unaryop.cpp b/src/layer/unaryop.cpp index 2fe77717ed3..b05add15cfb 100644 --- a/src/layer/unaryop.cpp +++ b/src/layer/unaryop.cpp @@ -14,9 +14,8 @@ #include "unaryop.h" -#include +// #include #include -#include namespace ncnn { diff --git a/src/layer/vulkan/binaryop_vulkan.cpp b/src/layer/vulkan/binaryop_vulkan.cpp index 3c0ad7299b5..37c0bb79e51 100644 --- a/src/layer/vulkan/binaryop_vulkan.cpp +++ b/src/layer/vulkan/binaryop_vulkan.cpp @@ -16,8 +16,6 @@ #include "layer_shader_type.h" -#include - namespace ncnn { BinaryOp_vulkan::BinaryOp_vulkan() diff --git a/src/layer/vulkan/convolution1d_vulkan.cpp b/src/layer/vulkan/convolution1d_vulkan.cpp new file mode 100644 index 00000000000..53dff49262b --- /dev/null +++ b/src/layer/vulkan/convolution1d_vulkan.cpp @@ -0,0 +1,401 @@ +// 
Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convolution1d_vulkan.h" + +#include "layer_shader_type.h" +#include "layer_type.h" + +namespace ncnn { + +Convolution1D_vulkan::Convolution1D_vulkan() +{ + support_vulkan = true; + support_image_storage = true; + + padding = 0; + + pipeline_convolution1d = 0; +} + +int Convolution1D_vulkan::create_pipeline(const Option& _opt) +{ + if (dynamic_weight) + { + support_vulkan = false; + support_image_storage = false; + return 0; + } + + Option opt = _opt; + + const int maxk = kernel_w; + int num_input = weight_data_size / maxk / num_output; + + int elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + { + padding = ncnn::create_layer(ncnn::LayerType::Padding); + padding->vkdev = vkdev; + + ncnn::ParamDict pd; + pd.set(0, 0); + pd.set(1, 0); + pd.set(2, pad_left); + pd.set(3, pad_right); + pd.set(4, 0); + pd.set(5, pad_value); + + padding->load_param(pd); + + padding->create_pipeline(opt); + } + + { + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + + weight_data_packed.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); + + for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) + { + float* g00 = weight_data_packed.channel(q / out_elempack); + + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < out_elempack; i++) + { + const Mat k0 = weight_data_r2.channel(q + i); + + for (int j = 0; j < elempack; j++) + { + const float* k00 = k0.row(p + j); + g00[0] = k00[k]; + g00++; + } + } + } + } + } + } + + if (bias_term) + { + convert_packing(bias_data, bias_data_packed, out_elempack, opt); + } + + { + std::vector specializations(7 + 4); + specializations[0].i = kernel_w; + specializations[1].i = dilation_w; + specializations[2].i = stride_w; + specializations[3].i = bias_term; + specializations[4].i = activation_type; + specializations[5].f = activation_params.w >= 1 ? activation_params[0] : 0.f; + specializations[6].f = activation_params.w == 2 ? 
activation_params[1] : 0.f; + specializations[7 + 0].i = 0; + specializations[7 + 1].i = 0; + specializations[7 + 2].i = 0; + specializations[7 + 3].i = 0; + + int shader_type_index = -1; + if (elempack == 1 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d; + if (elempack == 4 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack4; + if (elempack == 1 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack1to4; + if (elempack == 4 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d_pack4to1; + if (elempack == 8 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack8; + if (elempack == 1 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack1to8; + if (elempack == 8 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d_pack8to1; + if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack4to8; + if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack8to4; + + pipeline_convolution1d = new Pipeline(vkdev); + pipeline_convolution1d->set_optimal_local_size_xyz(1, 1, 1); + pipeline_convolution1d->create(shader_type_index, opt, specializations); + } + + return 0; +} + +int Convolution1D_vulkan::destroy_pipeline(const Option& opt) +{ + if (padding) + { + padding->destroy_pipeline(opt); + delete padding; + padding = 0; + } + + delete pipeline_convolution1d; + pipeline_convolution1d = 0; + + return 0; +} + +int Convolution1D_vulkan::upload_model(VkTransfer& cmd, const Option& opt) +{ + if (padding) + { + padding->upload_model(cmd, opt); + } + + if (support_image_storage && opt.use_image_storage) + { + cmd.record_upload(weight_data_packed, weight_data_gpu_image, opt); + } + else + { + cmd.record_upload(weight_data_packed, weight_data_gpu, opt); + } + + weight_data_packed.release(); + + if (bias_term) + { + if (support_image_storage && opt.use_image_storage) + { + cmd.record_upload(bias_data_packed, bias_data_gpu_image, opt); + } + else + { + cmd.record_upload(bias_data_packed, bias_data_gpu, opt); + } + + bias_data_packed.release(); + } + + return 0; +} + +int Convolution1D_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const +{ + int w = bottom_blob.w; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + + VkMat bottom_blob_bordered = bottom_blob; + if (pad_left > 0 || pad_right > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + padding->forward(bottom_blob, bottom_blob_bordered, cmd, opt_pad); + } + else if (pad_left == -233 && pad_right == -233) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad / 2; + padding_params[3] = wpad - wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + else if (pad_left == 
-234 && pad_right == -234) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad - wpad / 2; + padding_params[3] = wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + + int outw = (bottom_blob_bordered.w - kernel_extent_w) / stride_w + 1; + + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (out_elempack == 8) out_elemsize = 8 * 2u; + if (out_elempack == 4) out_elemsize = 4 * 2u; + if (out_elempack == 1) out_elemsize = 4u; + } + + top_blob.create(outw, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator); + if (top_blob.empty()) + return -100; + + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu; + bindings[3] = bias_data_gpu; + + std::vector constants(4); + constants[0].i = bottom_blob_bordered.w; + constants[1].i = bottom_blob_bordered.h; + constants[2].i = top_blob.w; + constants[3].i = top_blob.h; + + VkMat dispatcher; + dispatcher.w = (top_blob.w + 1) / 2; + dispatcher.h = (top_blob.h + 1) / 2; + dispatcher.c = 1; + + cmd.record_pipeline(pipeline_convolution1d, bindings, constants, dispatcher); + + return 0; +} + +int Convolution1D_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const +{ + int w = bottom_blob.w; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + + VkImageMat bottom_blob_bordered = bottom_blob; + if (pad_left > 0 || pad_right > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + padding->forward(bottom_blob, bottom_blob_bordered, cmd, opt_pad); + } + else if (pad_left == -233 && pad_right == -233) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkImageMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad / 2; + padding_params[3] = wpad - wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + else if (pad_left == -234 && pad_right == -234) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkImageMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* 
padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad - wpad / 2; + padding_params[3] = wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + + int outw = (bottom_blob_bordered.w - kernel_extent_w) / stride_w + 1; + + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (out_elempack == 8) out_elemsize = 8 * 2u; + if (out_elempack == 4) out_elemsize = 4 * 2u; + if (out_elempack == 1) out_elemsize = 4u; + } + + top_blob.create(outw, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator); + if (top_blob.empty()) + return -100; + + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu_image; + bindings[3] = bias_data_gpu_image; + + std::vector constants(4); + constants[0].i = bottom_blob_bordered.w; + constants[1].i = bottom_blob_bordered.h; + constants[2].i = top_blob.w; + constants[3].i = top_blob.h; + + VkImageMat dispatcher; + dispatcher.w = (top_blob.w + 1) / 2; + dispatcher.h = (top_blob.h + 1) / 2; + dispatcher.c = 1; + + cmd.record_pipeline(pipeline_convolution1d, bindings, constants, dispatcher); + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/vulkan/convolution1d_vulkan.h b/src/layer/vulkan/convolution1d_vulkan.h new file mode 100644 index 00000000000..4fb22040daa --- /dev/null +++ b/src/layer/vulkan/convolution1d_vulkan.h @@ -0,0 +1,53 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
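For readers skimming the new Convolution1D_vulkan host code above: the input is padded first (explicitly, or split evenly for the -233/-234 "same" modes), and the output length then follows the usual dilated-convolution formula. A minimal C++ sketch of that size computation, assuming explicit padding (the helper name is illustrative and not part of the patch):

    // Output length used by Convolution1D_vulkan::forward after padding.
    int conv1d_output_w(int w, int pad_left, int pad_right,
                        int kernel_w, int dilation_w, int stride_w)
    {
        const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
        return (w + pad_left + pad_right - kernel_extent_w) / stride_w + 1;
    }

Each shader invocation then produces a 2x2 tile of outputs, which is why the dispatcher above is sized to (outw + 1) / 2 by (outh + 1) / 2.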
+ +#ifndef LAYER_CONVOLUTION1D_VULKAN_H +#define LAYER_CONVOLUTION1D_VULKAN_H + +#include "convolution1d.h" + +namespace ncnn { + +class Convolution1D_vulkan : virtual public Convolution1D +{ +public: + Convolution1D_vulkan(); + + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + + virtual int upload_model(VkTransfer& cmd, const Option& opt); + + using Convolution1D::forward; + virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const; + virtual int forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const; + +public: + ncnn::Layer* padding; + + Mat weight_data_packed; + Mat bias_data_packed; + + VkMat weight_data_gpu; + VkMat bias_data_gpu; + + VkImageMat weight_data_gpu_image; + VkImageMat bias_data_gpu_image; + + Pipeline* pipeline_convolution1d; +}; + +} // namespace ncnn + +#endif // LAYER_CONVOLUTION1D_VULKAN_H diff --git a/src/layer/vulkan/priorbox_vulkan.cpp b/src/layer/vulkan/priorbox_vulkan.cpp index ba41fc96e59..5cfe341cd78 100644 --- a/src/layer/vulkan/priorbox_vulkan.cpp +++ b/src/layer/vulkan/priorbox_vulkan.cpp @@ -17,8 +17,6 @@ #include "layer_shader_type.h" #include "platform.h" -#include - namespace ncnn { PriorBox_vulkan::PriorBox_vulkan() diff --git a/src/layer/vulkan/shader/convolution1d.comp b/src/layer/vulkan/shader/convolution1d.comp new file mode 100644 index 00000000000..a23f8b1abed --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d.comp @@ -0,0 +1,165 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
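The buffer variant of the convolution1d shader below walks the input with v_offset = gx * stride_w, advancing by w per input channel, and the packed weights with w_offset = gy * h * kernel_w, advancing by kernel_w. A scalar C++ restatement of that index math for a single output element, assuming plain unpacked fp32 buffers (function name is hypothetical, not part of the patch):

    // Reference for one output element out[oc][ox], oc in [0, outh), ox in [0, outw).
    float conv1d_ref(const float* bottom, const float* weight, const float* bias,
                     int w, int h, int kernel_w, int dilation_w, int stride_w,
                     int ox, int oc)
    {
        float sum = bias ? bias[oc] : 0.f;
        for (int y = 0; y < h; y++) // input channels
        {
            for (int x = 0; x < kernel_w; x++)
            {
                const float v = bottom[y * w + ox * stride_w + x * dilation_w];
                const float k = weight[(oc * h + y) * kernel_w + x];
                sum += v * k;
            }
        }
        return sum; // the shader applies the activation to this value before storing
    }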
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfp weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afp k0 = image3d_ld1(weight_blob, ivec3(wx, y, gy2.x)); + afp k1 = image3d_ld1(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x * dilation_w); + + afp k0 = buffer_ld1(weight_data, w_offset.x + x); + afp k1 = buffer_ld1(weight_data, w_offset.y + x); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afp(sum1, activation_type, activation_param_0, 
activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack1to4.comp b/src/layer/vulkan/shader/convolution1d_pack1to4.comp new file mode 100644 index 00000000000..bcdd1778221 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack1to4.comp @@ -0,0 +1,165 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = 
int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, w_offset.x + x); + afpvec4 k1 = buffer_ld4(weight_data, w_offset.y + x); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack1to8.comp b/src/layer/vulkan/shader/convolution1d_pack1to8.comp new file mode 100644 index 00000000000..d91559ecd1c --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack1to8.comp @@ -0,0 +1,174 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.y)); + + sum0[0] += v0 * k0[0]; + sum0[1] += v0 * k0[1]; + sum1[0] += v1 * k0[0]; + sum1[1] += v1 * k0[1]; + sum2[0] += v0 * k1[0]; + sum2[1] += v0 * k1[1]; + sum3[0] += v1 * k1[0]; + sum3[1] += v1 * k1[1]; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x 
* dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, w_offset.x + x); + afpvec8 k1 = buffer_ld8(weight_data, w_offset.y + x); + + sum0[0] += v0 * k0[0]; + sum0[1] += v0 * k0[1]; + sum1[0] += v1 * k0[0]; + sum1[1] += v1 * k0[1]; + sum2[0] += v0 * k1[0]; + sum2[1] += v0 * k1[1]; + sum3[0] += v1 * k1[0]; + sum3[1] += v1 * k1[1]; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack4.comp b/src/layer/vulkan/shader/convolution1d_pack4.comp new file mode 100644 index 00000000000..7ce8ddb013d --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4.comp @@ -0,0 +1,196 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
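In the pack4 variant that follows, one kernel tap couples 4 input lanes to 4 output lanes, so each tap is loaded as a mat4 (one vec4 of weights per output lane, in the order the host packing loop writes them) and applied with GLSL's row-vector product v * k. A scalar sketch of that inner step, with w[o][i] denoting the weight of output lane o and input lane i (hypothetical helper, shown only to spell out the lane mapping):

    // One pack4 kernel tap: accumulate 4 input lanes into 4 output lanes.
    void pack4_tap(float sum[4], const float v[4], const float w[4][4])
    {
        for (int o = 0; o < 4; o++)     // output lane
            for (int i = 0; i < 4; i++) // input lane
                sum[o] += w[o][i] * v[i];
    }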
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +#if NCNN_fp16_packed || (NCNN_fp16_storage && !NCNN_fp16_arithmetic) +// GL_EXT_shader_16bit_storage does not define f16mat4 type :( +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +#else +layout (binding = 2) readonly buffer weight_blob { sfpmat4 weight_data[]; }; +#endif +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpmat4 k0 = afpmat4( + image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.x)) + ); + afpmat4 k1 = afpmat4( + image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.y)) + ); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 4; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * 
psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + +#if NCNN_fp16_packed || (NCNN_fp16_storage && !NCNN_fp16_arithmetic) + // GL_EXT_shader_16bit_storage does not define f16mat4 type :( + afpmat4 k0 = afpmat4( + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 0), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 1), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 2), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 3) + ); + afpmat4 k1 = afpmat4( + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 0), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 1), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 2), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 3) + ); +#else + afpmat4 k0 = sfp2afpmat4(weight_data[w_offset.x + x]); + afpmat4 k1 = sfp2afpmat4(weight_data[w_offset.y + x]); +#endif + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack4to1.comp b/src/layer/vulkan/shader/convolution1d_pack4to1.comp new file mode 100644 index 00000000000..b262fd689cc --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4to1.comp @@ -0,0 +1,165 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += dot(v0, k0); + sum1 += dot(v1, k0); + sum2 += dot(v0, k1); + sum3 += dot(v1, k1); + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, w_offset.x + x); + afpvec4 k1 = buffer_ld4(weight_data, w_offset.y + x); + + sum0 += dot(v0, k0); + sum1 += dot(v1, k0); + sum2 += dot(v0, k1); + sum3 += dot(v1, k1); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); 
+ sum1 = activation_afp(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack4to8.comp b/src/layer/vulkan/shader/convolution1d_pack4to8.comp new file mode 100644 index 00000000000..111972991d4 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4to8.comp @@ -0,0 +1,258 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec4 k2 = image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec4 k3 = image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec4 k4 = image3d_ld4(weight_blob, ivec3(wx + 4, y, gy2.x)); + afpvec4 k5 = image3d_ld4(weight_blob, ivec3(wx + 5, y, gy2.x)); + afpvec4 k6 = image3d_ld4(weight_blob, ivec3(wx + 6, y, gy2.x)); + afpvec4 k7 = image3d_ld4(weight_blob, ivec3(wx + 7, y, gy2.x)); + + afpvec4 k8 = image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec4 k9 = image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec4 ka = image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.y)); + 
afpvec4 kb = image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.y)); + afpvec4 kc = image3d_ld4(weight_blob, ivec3(wx + 4, y, gy2.y)); + afpvec4 kd = image3d_ld4(weight_blob, ivec3(wx + 5, y, gy2.y)); + afpvec4 ke = image3d_ld4(weight_blob, ivec3(wx + 6, y, gy2.y)); + afpvec4 kf = image3d_ld4(weight_blob, ivec3(wx + 7, y, gy2.y)); + + sum0[0].r += dot(v0, k0); + sum0[0].g += dot(v0, k1); + sum0[0].b += dot(v0, k2); + sum0[0].a += dot(v0, k3); + sum0[1].r += dot(v0, k4); + sum0[1].g += dot(v0, k5); + sum0[1].b += dot(v0, k6); + sum0[1].a += dot(v0, k7); + + sum1[0].r += dot(v1, k0); + sum1[0].g += dot(v1, k1); + sum1[0].b += dot(v1, k2); + sum1[0].a += dot(v1, k3); + sum1[1].r += dot(v1, k4); + sum1[1].g += dot(v1, k5); + sum1[1].b += dot(v1, k6); + sum1[1].a += dot(v1, k7); + + sum2[0].r += dot(v0, k8); + sum2[0].g += dot(v0, k9); + sum2[0].b += dot(v0, ka); + sum2[0].a += dot(v0, kb); + sum2[1].r += dot(v0, kc); + sum2[1].g += dot(v0, kd); + sum2[1].b += dot(v0, ke); + sum2[1].a += dot(v0, kf); + + sum3[0].r += dot(v1, k8); + sum3[0].g += dot(v1, k9); + sum3[0].b += dot(v1, ka); + sum3[0].a += dot(v1, kb); + sum3[1].r += dot(v1, kc); + sum3[1].g += dot(v1, kd); + sum3[1].b += dot(v1, ke); + sum3[1].a += dot(v1, kf); + + wx += 8; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 0); + afpvec4 k1 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 1); + afpvec4 k2 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 2); + afpvec4 k3 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 3); + afpvec4 k4 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 4); + afpvec4 k5 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 5); + afpvec4 k6 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 6); + afpvec4 k7 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 7); + + afpvec4 k8 = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 0); + afpvec4 k9 = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 1); + afpvec4 ka = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 2); + afpvec4 kb = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 3); + afpvec4 kc = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 4); + afpvec4 kd = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 5); + afpvec4 ke = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 6); + afpvec4 kf = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 7); + + sum0[0].r += dot(v0, k0); + sum0[0].g += dot(v0, k1); + sum0[0].b += dot(v0, k2); + sum0[0].a += dot(v0, k3); + sum0[1].r += dot(v0, k4); + sum0[1].g += dot(v0, k5); + sum0[1].b += dot(v0, k6); + sum0[1].a += dot(v0, k7); + + sum1[0].r += dot(v1, k0); + sum1[0].g += dot(v1, k1); + sum1[0].b += dot(v1, k2); + sum1[0].a += dot(v1, k3); + sum1[1].r += dot(v1, k4); + sum1[1].g += dot(v1, k5); + sum1[1].b += dot(v1, k6); + sum1[1].a += dot(v1, k7); + + sum2[0].r += dot(v0, k8); + sum2[0].g += dot(v0, k9); + sum2[0].b += dot(v0, ka); + sum2[0].a += dot(v0, kb); + sum2[1].r += dot(v0, kc); + sum2[1].g += dot(v0, kd); + sum2[1].b += dot(v0, ke); + sum2[1].a += dot(v0, kf); + + sum3[0].r += dot(v1, k8); + sum3[0].g += dot(v1, k9); + sum3[0].b += dot(v1, ka); + sum3[0].a += dot(v1, kb); + sum3[1].r += dot(v1, kc); + sum3[1].g += dot(v1, kd); + sum3[1].b += dot(v1, ke); + sum3[1].a += dot(v1, kf); + } + v_offset += 
psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack8.comp b/src/layer/vulkan/shader/convolution1d_pack8.comp new file mode 100644 index 00000000000..38271697a48 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8.comp @@ -0,0 +1,258 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec8 k2 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec8 k3 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec8 k4 = image3d_ld8(weight_blob, ivec3(wx + 4, y, gy2.x)); + afpvec8 k5 = image3d_ld8(weight_blob, ivec3(wx + 5, y, gy2.x)); + afpvec8 k6 = image3d_ld8(weight_blob, ivec3(wx + 6, y, gy2.x)); + afpvec8 k7 = image3d_ld8(weight_blob, ivec3(wx + 7, y, gy2.x)); + + afpvec8 k8 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec8 k9 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec8 ka = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.y)); + 
afpvec8 kb = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.y)); + afpvec8 kc = image3d_ld8(weight_blob, ivec3(wx + 4, y, gy2.y)); + afpvec8 kd = image3d_ld8(weight_blob, ivec3(wx + 5, y, gy2.y)); + afpvec8 ke = image3d_ld8(weight_blob, ivec3(wx + 6, y, gy2.y)); + afpvec8 kf = image3d_ld8(weight_blob, ivec3(wx + 7, y, gy2.y)); + + sum0[0].r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0[0].g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0[0].b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0[0].a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + sum0[1].r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum0[1].g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum0[1].b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum0[1].a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum1[0].r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1[0].g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1[0].b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1[0].a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + sum1[1].r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum1[1].g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum1[1].b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum1[1].a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + sum2[0].r += dot(v0[0], k8[0]) + dot(v0[1], k8[1]); + sum2[0].g += dot(v0[0], k9[0]) + dot(v0[1], k9[1]); + sum2[0].b += dot(v0[0], ka[0]) + dot(v0[1], ka[1]); + sum2[0].a += dot(v0[0], kb[0]) + dot(v0[1], kb[1]); + sum2[1].r += dot(v0[0], kc[0]) + dot(v0[1], kc[1]); + sum2[1].g += dot(v0[0], kd[0]) + dot(v0[1], kd[1]); + sum2[1].b += dot(v0[0], ke[0]) + dot(v0[1], ke[1]); + sum2[1].a += dot(v0[0], kf[0]) + dot(v0[1], kf[1]); + + sum3[0].r += dot(v1[0], k8[0]) + dot(v1[1], k8[1]); + sum3[0].g += dot(v1[0], k9[0]) + dot(v1[1], k9[1]); + sum3[0].b += dot(v1[0], ka[0]) + dot(v1[1], ka[1]); + sum3[0].a += dot(v1[0], kb[0]) + dot(v1[1], kb[1]); + sum3[1].r += dot(v1[0], kc[0]) + dot(v1[1], kc[1]); + sum3[1].g += dot(v1[0], kd[0]) + dot(v1[1], kd[1]); + sum3[1].b += dot(v1[0], ke[0]) + dot(v1[1], ke[1]); + sum3[1].a += dot(v1[0], kf[0]) + dot(v1[1], kf[1]); + + wx += 8; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 0); + afpvec8 k1 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 1); + afpvec8 k2 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 2); + afpvec8 k3 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 3); + afpvec8 k4 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 4); + afpvec8 k5 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 5); + afpvec8 k6 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 6); + afpvec8 k7 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 7); + + afpvec8 k8 = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 0); + afpvec8 k9 = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 1); + afpvec8 ka = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 2); + afpvec8 kb = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 3); + afpvec8 kc = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 4); + afpvec8 kd = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 5); + afpvec8 ke = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 6); + afpvec8 kf = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 7); + + sum0[0].r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0[0].g += dot(v0[0], k1[0]) 
+ dot(v0[1], k1[1]); + sum0[0].b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0[0].a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + sum0[1].r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum0[1].g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum0[1].b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum0[1].a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum1[0].r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1[0].g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1[0].b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1[0].a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + sum1[1].r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum1[1].g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum1[1].b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum1[1].a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + sum2[0].r += dot(v0[0], k8[0]) + dot(v0[1], k8[1]); + sum2[0].g += dot(v0[0], k9[0]) + dot(v0[1], k9[1]); + sum2[0].b += dot(v0[0], ka[0]) + dot(v0[1], ka[1]); + sum2[0].a += dot(v0[0], kb[0]) + dot(v0[1], kb[1]); + sum2[1].r += dot(v0[0], kc[0]) + dot(v0[1], kc[1]); + sum2[1].g += dot(v0[0], kd[0]) + dot(v0[1], kd[1]); + sum2[1].b += dot(v0[0], ke[0]) + dot(v0[1], ke[1]); + sum2[1].a += dot(v0[0], kf[0]) + dot(v0[1], kf[1]); + + sum3[0].r += dot(v1[0], k8[0]) + dot(v1[1], k8[1]); + sum3[0].g += dot(v1[0], k9[0]) + dot(v1[1], k9[1]); + sum3[0].b += dot(v1[0], ka[0]) + dot(v1[1], ka[1]); + sum3[0].a += dot(v1[0], kb[0]) + dot(v1[1], kb[1]); + sum3[1].r += dot(v1[0], kc[0]) + dot(v1[1], kc[1]); + sum3[1].g += dot(v1[0], kd[0]) + dot(v1[1], kd[1]); + sum3[1].b += dot(v1[0], ke[0]) + dot(v1[1], ke[1]); + sum3[1].a += dot(v1[0], kf[0]) + dot(v1[1], kf[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack8to1.comp b/src/layer/vulkan/shader/convolution1d_pack8to1.comp new file mode 100644 index 00000000000..213e01a3c45 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8to1.comp @@ -0,0 +1,166 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum1 += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum2 += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum3 += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, w_offset.x + x); + afpvec8 k1 = buffer_ld8(weight_data, 
w_offset.y + x); + + sum0 += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum1 += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum2 += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum3 += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afp(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/vulkan/shader/convolution1d_pack8to4.comp b/src/layer/vulkan/shader/convolution1d_pack8to4.comp new file mode 100644 index 00000000000..d574d1c8ebf --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8to4.comp @@ -0,0 +1,208 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int w = 0; +layout (constant_id = shape_constant_id_offset + 1) const int h = 0; + +layout (constant_id = shape_constant_id_offset + 2) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 3) const int outh = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int w; + int h; + + int outw; + int outh; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec8 k2 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec8 k3 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec8 k4 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec8 k5 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec8 k6 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.y)); + afpvec8 k7 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.y)); + + sum0.r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0.g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0.b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0.a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + + sum1.r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1.g += dot(v1[0], k1[0]) + dot(v1[1], 
k1[1]); + sum1.b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1.a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + + sum2.r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum2.g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum2.b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum2.a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum3.r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum3.g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum3.b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum3.a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + wx += 4; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 0); + afpvec8 k1 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 1); + afpvec8 k2 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 2); + afpvec8 k3 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 3); + afpvec8 k4 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 0); + afpvec8 k5 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 1); + afpvec8 k6 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 2); + afpvec8 k7 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 3); + + sum0.r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0.g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0.b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0.a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + + sum1.r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1.g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1.b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1.a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + + sum2.r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum2.g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum2.b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum2.a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum3.r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum3.g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum3.b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum3.a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index d3f62e09d36..14ad9d5f638 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -26,8 +26,6 @@ #endif // __AVX__ #endif // __SSE2__ -#include - namespace ncnn { BinaryOp_x86::BinaryOp_x86() diff 
--git a/src/layer/x86/bnll_x86.cpp b/src/layer/x86/bnll_x86.cpp index e082d79fc48..e2eb995d095 100644 --- a/src/layer/x86/bnll_x86.cpp +++ b/src/layer/x86/bnll_x86.cpp @@ -25,7 +25,6 @@ #endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__ -#include namespace ncnn { diff --git a/src/layer/x86/convolution_3x3_int8.h b/src/layer/x86/convolution_3x3_int8.h index a5c5dfe4e71..ceaf75b92e1 100644 --- a/src/layer/x86/convolution_3x3_int8.h +++ b/src/layer/x86/convolution_3x3_int8.h @@ -78,833 +78,6 @@ static void conv3x3s1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& } } -static void conv3x3s1_winograd23_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - kernel_tm.create(4 * 4, inch, outch, (size_t)2u); - - // G - const short ktm[4][3] = { - {2, 0, 0}, - {1, 1, 1}, - {1, -1, 1}, - {0, 0, 2} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[4][3]; - for (int i = 0; i < 4; i++) - { - tmp[i][0] = (short)k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = (short)k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = (short)k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 4; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 4; i++) - { - kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd23_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 2n+2, winograd F(2,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 1) / 2 * 2; - outh = (outh + 1) / 2 * 2; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(4 * 4, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {1.0f, 0.0f, -1.0f, 0.0f}, - // {0.0f, 1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 0.00f, 1.0f} - // }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 2; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[4], d1[4], d2[4], d3[4]; - short w0[4], w1[4], w2[4], w3[4]; - short t0[4], t1[4], t2[4], t3[4]; - // load - for (int n = 0; n < 4; n++) - { - d0[n] = r0[n]; 
- d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - } - // w = B_t * d - for (int n = 0; n < 4; n++) - { - w0[n] = d0[n] - d2[n]; - w1[n] = d1[n] + d2[n]; - w2[n] = d2[n] - d1[n]; - w3[n] = d3[n] - d1[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - } - // U = B_t * d_t - for (int n = 0; n < 4; n++) - { - d0[n] = t0[n] - t2[n]; - d1[n] = t1[n] + t2[n]; - d2[n] = t2[n] - t1[n]; - d3[n] = t3[n] - t1[n]; - } - // save to out_tm - for (int n = 0; n < 4; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 4] = d1[n]; - out_tm0[n + 8] = d2[n]; - out_tm0[n + 12] = d3[n]; - } - - r0 += 2; - r1 += 2; - r2 += 2; - r3 += 2; - - out_tm0 += 16; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator); - - int nn_outch = outch >> 2; - int remain_outch_start = nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - Mat out0_tm = top_blob_tm.channel(p); - Mat out1_tm = top_blob_tm.channel(p + 1); - Mat out2_tm = top_blob_tm.channel(p + 2); - Mat out3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p); - const Mat kernel1_tm = kernel_tm.channel(p + 1); - const Mat kernel2_tm = kernel_tm.channel(p + 2); - const Mat kernel3_tm = kernel_tm.channel(p + 3); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - int* output1_tm = out1_tm.row(i); - int* output2_tm = out2_tm.row(i); - int* output3_tm = out3_tm.row(i); - - int sum0[16] = {0}; - int sum1[16] = {0}; - int sum2[16] = {0}; - int sum3[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r1[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r2[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r3[n] * k0[n]; - k0 -= 16 * 3; - - sum1[n] += (int)r0[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r1[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r2[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r3[n] * k1[n]; - k1 -= 16 * 3; - - sum2[n] += (int)r0[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r1[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r2[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r3[n] * k2[n]; - k2 -= 16 * 3; - - sum3[n] += (int)r0[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r1[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r2[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r3[n] * k3[n]; - k3 -= 16 * 3; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = 
kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum1[n] += (int)r0[n] * k1[n]; - sum2[n] += (int)r0[n] * k2[n]; - sum3[n] += (int)r0[n] * k3[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - output1_tm[n] = sum1[n]; - output2_tm[n] = sum2[n]; - output3_tm[n] = sum3[n]; - } - } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel0_tm.row(q + 1); - const short* k2 = kernel0_tm.row(q + 2); - const short* k3 = kernel0_tm.row(q + 3); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum0[n] += (int)r1[n] * k1[n]; - sum0[n] += (int)r2[n] * k2[n]; - sum0[n] += (int)r3[n] * k3[n]; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[2][4] = { - // {1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 1.0f} - // }; - - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j = 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 2); - int* outRow1 = out.row(j * 2 + 1); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int s0[4], s1[4], s2[4], s3[4]; - int w0[4], w1[4]; - int d0[2], d1[2], d2[2], d3[2]; - int o0[2], o1[2]; - // load - for (int n = 0; n < 4; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 4]; - s2[n] = out_tile[n + 8]; - s3[n] = out_tile[n + 12]; - } - // w = A_T * W - for (int n = 0; n < 4; n++) - { - w0[n] = s0[n] + s1[n] + s2[n]; - w1[n] = s1[n] - s2[n] + s3[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - } - // Y = A_T * w_t - for (int n = 0; n < 2; n++) - { - o0[n] = d0[n] + d1[n] + d2[n]; - o1[n] = d1[n] - d2[n] + d3[n]; - } - // save to top blob tm,why right 2,because the G' = G*2 - outRow0[0] = o0[0] >> 2; - outRow0[1] = o0[1] >> 2; - outRow1[0] = o1[0] >> 2; - outRow1[1] = o1[1] >> 2; - - outRow0 += 2; - outRow1 += 2; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - -static void conv3x3s1_winograd43_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int 
inch, int outch, const Option& opt) -{ - kernel_tm.create(6 * 6, inch, outch, (size_t)2u); - - // G - // const float ktm[6][3] = { - // { 1.0f/4, 0.0f, 0.0f}, - // { -1.0f/6, -1.0f/6, -1.0f/6}, - // { -1.0f/6, 1.0f/6, -1.0f/6}, - // { 1.0f/24, 1.0f/12, 1.0f/6}, - // { 1.0f/24, -1.0f/12, 1.0f/6}, - // { 0.0f, 0.0f, 1.0f} - // }; - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 24} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd43_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2, winograd F(4,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(6 * 6, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r03 + r04 - // 2 = 4 * (r01 - r02) - r03 + r04 - // 3 = -2 * r01 - r02 + 2 * r03 + r04 - // 4 = 2 * r01 - r02 - 2 * r03 + r04 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 4; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - const signed char* r4 = r3 + w; - const signed char* r5 = r4 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6]; - short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6]; - short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6]; - - // load - for (int n = 0; n < 6; n++) - 
{ - d0[n] = r0[n]; - d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - d4[n] = r4[n]; - d5[n] = r5[n]; - } - // w = B_t * d - for (int n = 0; n < 6; n++) - { - w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n]; - w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n]; - w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n]; - w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n]; - w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n]; - w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t4[0] = w0[4]; - t5[0] = w0[5]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t4[1] = w1[4]; - t5[1] = w1[5]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t4[2] = w2[4]; - t5[2] = w2[5]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - t4[3] = w3[4]; - t5[3] = w3[5]; - t0[4] = w4[0]; - t1[4] = w4[1]; - t2[4] = w4[2]; - t3[4] = w4[3]; - t4[4] = w4[4]; - t5[4] = w4[5]; - t0[5] = w5[0]; - t1[5] = w5[1]; - t2[5] = w5[2]; - t3[5] = w5[3]; - t4[5] = w5[4]; - t5[5] = w5[5]; - } - // d = B_t * d_t - for (int n = 0; n < 6; n++) - { - d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n]; - d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n]; - d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n]; - d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n]; - d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n]; - d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n]; - } - // save to out_tm - for (int n = 0; n < 6; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 6] = d1[n]; - out_tm0[n + 12] = d2[n]; - out_tm0[n + 18] = d3[n]; - out_tm0[n + 24] = d4[n]; - out_tm0[n + 30] = d5[n]; - } - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - r5 += 4; - - out_tm0 += 36; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[36] = {0}; - - for (int q = 0; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 36; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 36; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + r01 + r02 + r03 + r04 - // 1 = r01 - r02 + 2 * (r03 - r04) - // 2 = r01 + r02 + 4 * (r03 + r04) - // 3 = r01 - r02 + 8 * (r03 - r04) + r05 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j 
= 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 4); - int* outRow1 = out.row(j * 4 + 1); - int* outRow2 = out.row(j * 4 + 2); - int* outRow3 = out.row(j * 4 + 3); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6]; - int w0[6], w1[6], w2[6], w3[6]; - int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4]; - int o0[4], o1[4], o2[4], o3[4]; - // load - for (int n = 0; n < 6; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 6]; - s2[n] = out_tile[n + 12]; - s3[n] = out_tile[n + 18]; - s4[n] = out_tile[n + 24]; - s5[n] = out_tile[n + 30]; - } - // w = A_T * W - for (int n = 0; n < 6; n++) - { - w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n]; - w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]; - w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]; - w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + s5[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d0[2] = w2[0]; - d0[3] = w3[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d1[2] = w2[1]; - d1[3] = w3[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d2[2] = w2[2]; - d2[3] = w3[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - d3[2] = w2[3]; - d3[3] = w3[3]; - d4[0] = w0[4]; - d4[1] = w1[4]; - d4[2] = w2[4]; - d4[3] = w3[4]; - d5[0] = w0[5]; - d5[1] = w1[5]; - d5[2] = w2[5]; - d5[3] = w3[5]; - } - // Y = A_T * w_t - for (int n = 0; n < 4; n++) - { - o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n]; - o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n]; - o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n]; - o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n]; - } - // save to top blob tm - for (int n = 0; n < 4; n++) - { - outRow0[n] = o0[n] / 576; - outRow1[n] = o1[n] / 576; - outRow2[n] = o2[n] / 576; - outRow3[n] = o3[n] / 576; - } - - outRow0 += 4; - outRow1 += 4; - outRow2 += 4; - outRow3 += 4; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - static void conv3x3s2_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) { int w = bottom_blob.w; diff --git a/src/layer/x86/convolution_3x3_pack8to1_int8.h b/src/layer/x86/convolution_3x3_pack8to1_int8.h deleted file mode 100644 index d5957faf6d8..00000000000 --- a/src/layer/x86/convolution_3x3_pack8to1_int8.h +++ /dev/null @@ -1,1125 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * 
ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8to1.create(8 * inch / 8, 36, outch / 4 + outch % 4, (size_t)2u * 4, 4); - - int p = 0; - for (; p + 3 < outch; p += 4) - { - Mat g0 = kernel_tm_pack8to1.channel(p / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(p + i).row(q + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - - Mat g0 = kernel_tm_pack8to1.channel(p / 4 + p % 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - - g00 += 1; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for 
num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use _mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - 
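// Annotation (editor note, not a line of the deleted file): this second pass over m applies the
// same integer B^T coefficients that the first pass applied to each input row, so the six strided
// stores below write one 8-channel tile of V = B^T d B into bottom_blob_tm, following the row
// formulas listed in the itm comment at the top of this input-transform block.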
_mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - - const 
Mat kernel0_tm = kernel_tm.channel(p / 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - __m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = 
_mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - int sum[16]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); - _mm256_storeu_si256((__m256i*)(sum + 8), _sum04_15); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm[2] = sum[8]; - output1_tm[2] = sum[9]; - output2_tm[2] = sum[10]; - output3_tm[2] = sum[11]; - output0_tm[3] = sum[12]; - output1_tm[3] = sum[13]; - output2_tm[3] = sum[14]; - output3_tm[3] = sum[15]; - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = 
_mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - int sum[8]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - int sum[8]; - _mm_storeu_si128((__m128i*)sum, _sum00); - _mm_storeu_si128((__m128i*)(sum + 4), _sum10); -#endif - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm += 2; - output1_tm += 2; - output2_tm += 2; - output3_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); - -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i 
_valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - int sum[4]; - _mm_storeu_si128((__m128i*)sum, _sum0); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 4 + p % 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - __m256i _sum01 = _mm256_setzero_si256(); - __m256i _sum23 = _mm256_setzero_si256(); - - for (int q = 0; q < inch; q++) - { - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); - _sum23 = _mm256_dpwssd_epi32(_sum23, _val23, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); - _sum23 = _mm256_add_epi32(_sum23, _mm256_madd_epi16(_val23, _w01)); -#endif - - k0 += 8; - r0 += 32; - } - - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum23, 0); - __m128i _sum3 = 
_mm256_extracti128_si256(_sum23, 1); - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm[2] = _mm_reduce_add_epi32(_sum2); - output0_tm[3] = _mm_reduce_add_epi32(_sum3); - output0_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - -#if __AVX2__ - __m256i _sum01 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); -#endif - - for (int q = 0; q < inch; q++) - { -#if __AVX2__ - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); -#endif -#else - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val1, _w0, _sum1); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum1); -#endif -#endif - - k0 += 8; - r0 += 16; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); -#endif - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - __m128i _sum0 = _mm_setzero_si128(); - - for (int q = 0; q < inch; q++) - { - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); -#endif - - k0 += 8; - r0 += 8; - } - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm++; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - const int* output0_tm_0 = (const int*)out0_tm + (i * 
w_tm / 6 + j) * 1; - const int* output0_tm_1 = output0_tm_0 + tiles * 1; - const int* output0_tm_2 = output0_tm_0 + tiles * 2; - const int* output0_tm_3 = output0_tm_0 + tiles * 3; - const int* output0_tm_4 = output0_tm_0 + tiles * 4; - const int* output0_tm_5 = output0_tm_0 + tiles * 5; - - int* output0 = out0.row(i * 4) + j * 4; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = output0_tm_0[0] + tmp02a + tmp02b; - tmp[1][m] = tmp13a + tmp13b * 2; - tmp[2][m] = tmp02a + tmp02b * 4; - tmp[3][m] = output0_tm_5[0] * 4 + tmp13a + tmp13b * 8; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - for (int m = 5; m < 6; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = (output0_tm_0[0] + tmp02a + tmp02b) * 4; - tmp[1][m] = (tmp13a + tmp13b * 2) * 4; - tmp[2][m] = (tmp02a + tmp02b * 4) * 4; - tmp[3][m] = (output0_tm_5[0] * 4 + tmp13a + tmp13b * 8) * 4; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - - for (int m = 0; m < 4; m++) - { - const int* tmp0 = tmp[m]; - - int tmp02a = tmp0[1] + tmp0[2]; - int tmp13a = tmp0[1] - tmp0[2]; - - int tmp02b = tmp0[3] + tmp0[4]; - int tmp13b = tmp0[3] - tmp0[4]; - - output0[0] = (tmp0[0] + tmp02a + tmp02b) / 576; - output0[1] = (tmp13a + tmp13b * 2) / 576; - output0[2] = (tmp02a + tmp02b * 4) / 576; - output0[3] = (tmp0[5] + tmp13a + tmp13b * 8) / 576; - - output0 += outw; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_pack8to4_int8.h b/src/layer/x86/convolution_3x3_pack8to4_int8.h deleted file mode 100644 index 2bb48ce1903..00000000000 --- a/src/layer/x86/convolution_3x3_pack8to4_int8.h +++ /dev/null @@ -1,945 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
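Annotation (editor note, not part of the patch): the two headers removed in this patch, convolution_3x3_pack8to1_int8.h above and convolution_3x3_pack8to4_int8.h beginning here, implement the same Winograd F(4x4, 3x3) int8 convolution and differ mainly in how the transformed outputs are packed (elempack 1 vs elempack 4); the new convolution_3x3_winograd_int8.h added at the end of this diff takes their place. As a reading aid, the scalar sketch below reconstructs the transform pipeline both files encode, using the ktm table and the commented itm/otm matrices from the deleted code; the function name, array layout and the 64-bit accumulators are illustrative choices, not taken from ncnn. The sketch uses the canonical 24-scaled G row {0, 0, 24}, whereas the deleted files store {0, 0, 6} in ktm and multiply the index-5 terms by 4 in their output transforms, which should yield the same integer sums before the shared division by 576.

static void winograd43_int8_reference(const signed char d[6][6], const signed char g[3][3], int y[4][4])
{
    // 24*G, B^T and A^T for F(4x4, 3x3); B^T and A^T match the itm/otm comments in the deleted code
    static const int G24[6][3] = {{6, 0, 0}, {-4, -4, -4}, {-4, 4, -4}, {1, 2, 4}, {1, -2, 4}, {0, 0, 24}};
    static const int BT[6][6] = {
        {4, 0, -5, 0, 1, 0},
        {0, -4, -4, 1, 1, 0},
        {0, 4, -4, -1, 1, 0},
        {0, -2, -1, 2, 1, 0},
        {0, 2, -1, -2, 1, 0},
        {0, 4, 0, -5, 0, 1}
    };
    static const int AT[4][6] = {
        {1, 1, 1, 1, 1, 0},
        {0, 1, -1, 2, -2, 0},
        {0, 1, 1, 4, 4, 0},
        {0, 1, -1, 8, -8, 1}
    };

    int U[6][6] = {0};  // kernel transform U = (24G) g (24G)^T, carries a 24*24 = 576 scale
    int V[6][6] = {0};  // input transform  V = B^T d B, integer coefficients only
    long long M[6][6];  // elementwise product, 64-bit to stay safe over the full int8 range

    for (int i = 0; i < 6; i++)
        for (int j = 0; j < 6; j++)
            for (int a = 0; a < 3; a++)
                for (int b = 0; b < 3; b++)
                    U[i][j] += G24[i][a] * g[a][b] * G24[j][b];

    for (int i = 0; i < 6; i++)
        for (int j = 0; j < 6; j++)
            for (int a = 0; a < 6; a++)
                for (int b = 0; b < 6; b++)
                    V[i][j] += BT[i][a] * d[a][b] * BT[j][b];

    // the "dot" stage of the SIMD code accumulates exactly these products across input channels
    for (int i = 0; i < 6; i++)
        for (int j = 0; j < 6; j++)
            M[i][j] = (long long)U[i][j] * V[i][j];

    // output transform Y = A^T M A, then undo the 576x scale introduced by the kernel transform
    for (int i = 0; i < 4; i++)
        for (int j = 0; j < 4; j++)
        {
            long long sum = 0;
            for (int a = 0; a < 6; a++)
                for (int b = 0; b < 6; b++)
                    sum += (long long)AT[i][a] * M[a][b] * AT[j][b];
            y[i][j] = (int)(sum / 576);
        }
}

For a single input channel this should reproduce, per 4x4 output tile, the int32 values the deleted SIMD paths write into top_blob before requantization; with several input channels the per-channel M terms are summed before the output transform, which is what the packed dot loops in both deleted files do.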
- -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * 
ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8.create(inch / 8, 36, outch / 4, (size_t)2u * 32, 32); - - int q = 0; - for (; q + 3 < outch; q += 4) - { - Mat g0 = kernel_tm_pack8.channel(q / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use 
_mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - _mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - 
r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u * 4, 4, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - 
__m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); - _mm256_storeu_si256((__m256i*)(output0_tm + 8), _sum04_15); - output0_tm += 16; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = 
inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = _mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - 
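// Annotation (editor note, not a line of the deleted file): the 4x8 transpose and the additions
// above reduce the madd partial sums to one int32 per (tile, output channel), and the cross-lane
// permute below reorders them so the single 256-bit store that follows writes the two tiles in
// pack-4 order (tile i channels 0..3, then tile i+1 channels 0..3) straight into top_blob_tm.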
_sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - _mm_storeu_si128((__m128i*)output0_tm, _sum00); - _mm_storeu_si128((__m128i*)(output0_tm + 4), _sum10); -#endif - output0_tm += 8; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - 
_tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - _mm_storeu_si128((__m128i*)output0_tm, _sum0); - output0_tm += 4; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6][4]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 4; - const int* output0_tm_1 = output0_tm_0 + tiles * 4; - const int* output0_tm_2 = output0_tm_0 + tiles * 8; - const int* output0_tm_3 = output0_tm_0 + tiles * 12; - const int* output0_tm_4 = output0_tm_0 + tiles * 16; - const int* output0_tm_5 = output0_tm_0 + tiles * 20; - - int* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - __m128i _out0tm0 = _mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - for (int m = 5; m < 6; m++) - { - __m128i _out0tm0 = 
_mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _tmp0m = _mm_slli_epi32(_tmp0m, 2); - _tmp1m = _mm_slli_epi32(_tmp1m, 2); - _tmp2m = _mm_slli_epi32(_tmp2m, 2); - _tmp3m = _mm_slli_epi32(_tmp3m, 2); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _tmp02a = _mm_add_epi32(_tmp01, _tmp02); - __m128i _tmp13a = _mm_sub_epi32(_tmp01, _tmp02); - - __m128i _tmp02b = _mm_add_epi32(_tmp03, _tmp04); - __m128i _tmp13b = _mm_sub_epi32(_tmp03, _tmp04); - - __m128i _out00 = _mm_add_epi32(_mm_add_epi32(_tmp00, _tmp02a), _tmp02b); - __m128i _out01 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _out02 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _out03 = _mm_add_epi32(_mm_add_epi32(_tmp05, _tmp13a), _mm_slli_epi32(_tmp13b, 3)); - - // TODO use integer trick for division by 576 - __m128 _v576 = _mm_set1_ps(1.0 / 576); - _out00 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out00), _v576)); - _out01 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out01), _v576)); - _out02 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out02), _v576)); - _out03 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out03), _v576)); - - _mm_storeu_si128((__m128i*)output0, _out00); - _mm_storeu_si128((__m128i*)(output0 + 4), _out01); - _mm_storeu_si128((__m128i*)(output0 + 8), _out02); - _mm_storeu_si128((__m128i*)(output0 + 12), _out03); - - output0 += outw * 4; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h new file mode 100644 index 00000000000..8c7b891b0dd --- /dev/null +++ b/src/layer/x86/convolution_3x3_winograd_int8.h @@ -0,0 +1,6407 @@ +// Tencent is pleased to support the open source community by making ncnn 
available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif +#endif + +static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) +{ + const int N = max_kk * batch; + + for (int b = 0; b < batch; b++) + { + short* pp = AT.row(b); + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + pp[16] = p0[N * 8]; + pp[17] = p0[N * 8 + batch]; + pp[18] = p0[N * 9]; + pp[19] = p0[N * 9 + batch]; + pp[20] = p0[N * 10]; + pp[21] = p0[N * 10 + batch]; + pp[22] = p0[N * 11]; + pp[23] = p0[N * 11 + batch]; + pp[24] = p0[N * 12]; + pp[25] = p0[N * 12 + batch]; + pp[26] = p0[N * 13]; + pp[27] = p0[N * 13 + batch]; + pp[28] = p0[N * 14]; + pp[29] = p0[N * 14 + batch]; + pp[30] = p0[N * 15]; + pp[31] = p0[N * 15 + batch]; + p0 += batch * 2; + pp += 32; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = 
p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + pp[8] = p0[N * 8]; + pp[9] = p0[N * 9]; + pp[10] = p0[N * 10]; + pp[11] = p0[N * 11]; + pp[12] = p0[N * 12]; + pp[13] = p0[N * 13]; + pp[14] = p0[N * 14]; + pp[15] = p0[N * 15]; + p0 += batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + p0 += batch * 2; + pp += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + p0 += batch; + pp += 8; + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + p0 += batch * 2; + pp += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + p0 += batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + p0 += batch * 2; + pp += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + p0 += batch; + pp += 2; + } + } + for (; ii < max_ii; ii++) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += batch; + pp += 1; + } + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) +{ + #pragma omp parallel for num_threads(nT) + for (int b = 0; b < batch; b++) + { + short* pp = BT.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _r4 = _mm512_loadu_si512((const __m512i*)(p0 + 128)); + __m512i _r5 = _mm512_loadu_si512((const __m512i*)(p0 + 160)); + __m512i _r6 = _mm512_loadu_si512((const __m512i*)(p0 + 192)); + __m512i _r7 = _mm512_loadu_si512((const __m512i*)(p0 + 224)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_r4, _r6, 
_MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_r4, _r6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _r4 = _mm512_unpacklo_epi32(_tmp4, _tmp5); + _r5 = _mm512_unpackhi_epi32(_tmp4, _tmp5); + _r6 = _mm512_unpacklo_epi32(_tmp6, _tmp7); + _r7 = _mm512_unpackhi_epi32(_tmp6, _tmp7); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _tmp4 = _mm512_unpacklo_epi64(_r4, _r6); + _tmp5 = _mm512_unpackhi_epi64(_r4, _r6); + _tmp6 = _mm512_unpacklo_epi64(_r5, _r7); + _tmp7 = _mm512_unpackhi_epi64(_r5, _r7); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 1, 3, 1)); + _r5 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _r6 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _r7 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + _mm512_storeu_si512((__m512i*)(pp + 128), _r4); + _mm512_storeu_si512((__m512i*)(pp + 160), _r5); + _mm512_storeu_si512((__m512i*)(pp + 192), _r6); + _mm512_storeu_si512((__m512i*)(pp + 224), _r7); + p0 += max_jj * batch * 16; + pp += 256; + } + p0 -= (b * max_jj + jj) * 16; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(3, 1, 3, 1)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp2 = _mm512_permutex_epi64(_r2, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp3 = _mm512_permutex_epi64(_r3, _MM_SHUFFLE(3, 1, 2, 0)); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); + _r2 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 8; + pp += 128; + } + p0 -= (b * max_jj + 
jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + _mm512_storeu_si512((__m512i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 32; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + _mm256_store_si256((__m256i*)pp, _r0); + p0 += max_jj * batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 16; + pp += 128; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + __m256 _r1 = _mm256_loadu_ps((const float*)(p0 + 16)); + __m256 _r2 = _mm256_loadu_ps((const float*)(p0 + 32)); + __m256 _r3 = _mm256_loadu_ps((const float*)(p0 + 48)); + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 3, 0, 1)); + _r0 = _mm256_unpacklo_ps(_tmp0, _tmp1); + _r1 = _mm256_unpackhi_ps(_tmp0, _tmp1); + _r2 = _mm256_unpacklo_ps(_tmp2, _tmp3); + _r3 = _mm256_unpackhi_ps(_tmp2, _tmp3); + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _mm256_storeu_ps((float*)pp, _tmp0); + _mm256_storeu_ps((float*)(pp + 16), _tmp1); + _mm256_storeu_ps((float*)(pp + 32), _tmp2); + _mm256_storeu_ps((float*)(pp + 48), _tmp3); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + 
__m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + __m128i _r4 = _mm_load_si128((const __m128i*)(p0 + 8 * 4)); + __m128i _r5 = _mm_load_si128((const __m128i*)(p0 + 8 * 5)); + __m128i _r6 = _mm_load_si128((const __m128i*)(p0 + 8 * 6)); + __m128i _r7 = _mm_load_si128((const __m128i*)(p0 + 8 * 7)); + transpose4x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); + _mm_store_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_store_si128((__m128i*)(pp + 8 * 3), _r3); + _mm_store_si128((__m128i*)(pp + 8 * 4), _r4); + _mm_store_si128((__m128i*)(pp + 8 * 5), _r5); + _mm_store_si128((__m128i*)(pp + 8 * 6), _r6); + _mm_store_si128((__m128i*)(pp + 8 * 7), _r7); +#endif // __AVX__ + p0 += max_jj * batch * 8; + pp += 64; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + _mm256_storeu_ps((float*)pp, _r0); +#else + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); +#endif // __AVX__ + p0 += max_jj * batch * 2; + pp += 16; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_store_si128((__m128i*)pp, _r0); + p0 += max_jj * batch; + pp += 8; + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _r1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi64(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi64(_tmp0, _tmp1); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + p0 += max_jj * batch * 16; + pp += 64; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + transpose4x4_epi32(_r0, _r1, _r2, _r3); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_storeu_si128((__m128i*)(pp + 8 * 3), _r3); + p0 += max_jj * batch * 8; + pp += 32; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 8; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < 
max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + __m256i _r1 = _mm256_load_si256((const __m256i*)(p0 + 16)); + transpose8x2_epi32(_r0, _r1); + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + p0 += max_jj * batch * 16; + pp += 32; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _tmp0); + _mm_storeu_si128((__m128i*)(pp + 8), _tmp1); + p0 += max_jj * batch * 8; + pp += 16; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch * 2; + pp += 4; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch; + pp += 2; + } + } + for (; jj < max_jj; jj++) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + _mm256_storeu_si256((__m256i*)pp, _r0); + p0 += max_jj * batch * 16; + pp += 16; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 8; + pp += 8; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch * 2; + pp += 2; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += max_jj * batch; + pp += 1; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk, bool k_end) +{ + int* outptr = top_blob; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + __m512i _sum8; + __m512i _sum9; + __m512i _suma; + __m512i _sumb; + __m512i _sumc; + __m512i _sumd; + __m512i _sume; + __m512i _sumf; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + _sum8 = 
_mm512_setzero_si512(); + _sum9 = _mm512_setzero_si512(); + _suma = _mm512_setzero_si512(); + _sumb = _mm512_setzero_si512(); + _sumc = _mm512_setzero_si512(); + _sumd = _mm512_setzero_si512(); + _sume = _mm512_setzero_si512(); + _sumf = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + _sum8 = _mm512_load_si512((const __m512i*)(outptr + 128)); + _sum9 = _mm512_load_si512((const __m512i*)(outptr + 128 + 16)); + _suma = _mm512_load_si512((const __m512i*)(outptr + 128 + 32)); + _sumb = _mm512_load_si512((const __m512i*)(outptr + 128 + 48)); + _sumc = _mm512_load_si512((const __m512i*)(outptr + 128 + 64)); + _sumd = _mm512_load_si512((const __m512i*)(outptr + 128 + 80)); + _sume = _mm512_load_si512((const __m512i*)(outptr + 128 + 96)); + _sumf = _mm512_load_si512((const __m512i*)(outptr + 128 + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_dpwssd_epi32(_suma, _pA2, _pB2); + _sumb = _mm512_dpwssd_epi32(_sumb, _pA2, _pB3); + _sumc = _mm512_dpwssd_epi32(_sumc, _pA3, _pB0); + _sumd = _mm512_dpwssd_epi32(_sumd, _pA3, _pB1); + _sume = _mm512_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_dpwssd_epi32(_sumf, _pA3, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); + _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); + _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA2, _pB2)); + _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA2, _pB3)); + _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA3, _pB0)); + _sumd = 
_mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA3, _pB1)); + _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); + _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); +#endif + + pA += 32; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pB)); + + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + __m512i _s8 = _mm512_mullo_epi32(_pA2, _pB0); + __m512i _s9 = _mm512_mullo_epi32(_pA2, _pB1); + __m512i _sa = _mm512_mullo_epi32(_pA2, _pB2); + __m512i _sb = _mm512_mullo_epi32(_pA2, _pB3); + __m512i _sc = _mm512_mullo_epi32(_pA3, _pB0); + __m512i _sd = _mm512_mullo_epi32(_pA3, _pB1); + __m512i _se = _mm512_mullo_epi32(_pA3, _pB2); + __m512i _sf = _mm512_mullo_epi32(_pA3, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum8 = _mm512_add_epi32(_sum8, _s8); + _sum9 = _mm512_add_epi32(_sum9, _s9); + _suma = _mm512_add_epi32(_suma, _sa); + _sumb = _mm512_add_epi32(_sumb, _sb); + _sumc = _mm512_add_epi32(_sumc, _sc); + _sumd = _mm512_add_epi32(_sumd, _sd); + _sume = _mm512_add_epi32(_sume, _se); + _sumf = _mm512_add_epi32(_sumf, _sf); + + pA += 16; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 01 12 23 30 45 56 67 74 89 9a ab b8 cd de ef fc + // 02 13 20 31 46 57 64 75 8a 9b a8 b9 ce df ec fd + // 03 10 21 32 47 54 65 76 8b 98 a9 ba cf dc ed fe + // c0 d1 e2 f3 04 15 26 37 48 59 6a 7b 8c 9d ae bf + // c1 d2 e3 f0 05 16 27 34 49 5a 6b 78 8d 9e af bc + // c2 d3 e0 f1 06 17 24 35 4a 5b 68 79 8e 9f ac bd + // c3 d0 e1 f2 07 14 25 36 4b 58 69 7a 8f 9c ad be + // 80 91 a2 b3 c4 d5 e6 f7 08 19 2a 3b 4c 5d 6e 7f + // 81 92 a3 b0 c5 d6 e7 f4 09 1a 2b 38 4d 5e 6f 7c + // 82 93 a0 b1 c6 d7 e4 f5 0a 1b 28 39 4e 5f 6c 7d + // 83 90 a1 b2 c7 d4 e5 f6 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 84 95 a6 b7 c8 d9 ea fb 0c 1d 2e 3f + // 41 52 63 70 85 96 a7 b4 c9 da eb f8 0d 1e 2f 3c + // 42 53 60 71 86 97 a4 b5 ca db e8 f9 0e 1f 2c 3d + // 43 50 61 72 87 94 a5 b6 cb d8 e9 fa 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 88 98 a8 b8 cc dc ec fc + // 01 11 21 31 45 55 65 75 89 99 a9 b9 cd dd ed fd + // 02 12 22 32 46 56 66 76 8a 9a aa ba ce de ee fe + // 03 13 23 33 47 57 67 77 8b 9b ab bb cf df ef ff + // c0 d0 e0 f0 04 14 24 34 48 58 68 78 8c 9c ac bc + // c1 d1 e1 f1 05 15 25 35 49 59 69 79 8d 9d ad bd + // c2 d2 e2 f2 06 16 26 36 4a 5a 6a 7a 8e 9e ae be + // c3 d3 e3 f3 07 
17 27 37 4b 5b 6b 7b 8f 9f af bf + // 80 90 a0 b0 c4 d4 e4 f4 08 18 28 38 4c 5c 6c 7c + // 81 91 a1 b1 c5 d5 e5 f5 09 19 29 39 4d 5d 6d 7d + // 82 92 a2 b2 c6 d6 e6 f6 0a 1a 2a 3a 4e 5e 6e 7e + // 83 93 a3 b3 c7 d7 e7 f7 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 84 94 a4 b4 c8 d8 e8 f8 0c 1c 2c 3c + // 41 51 61 71 85 95 a5 b5 c9 d9 e9 f9 0d 1d 2d 3d + // 42 52 62 72 86 96 a6 b6 ca da ea fa 0e 1e 2e 3e + // 43 53 63 73 87 97 a7 b7 cb db eb fb 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _suma = _mm512_shuffle_epi32(_suma, _MM_PERM_BADC); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_ADCB); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sume = _mm512_shuffle_epi32(_sume, _MM_PERM_BADC); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + __m512i _tmp8 = _mm512_unpacklo_epi32(_sum8, _sumb); + __m512i _tmp9 = _mm512_unpackhi_epi32(_sum8, _sumb); + __m512i _tmpa = _mm512_unpacklo_epi32(_suma, _sum9); + __m512i _tmpb = _mm512_unpackhi_epi32(_suma, _sum9); + __m512i _tmpc = _mm512_unpacklo_epi32(_sumc, _sumf); + __m512i _tmpd = _mm512_unpackhi_epi32(_sumc, _sumf); + __m512i _tmpe = _mm512_unpacklo_epi32(_sume, _sumd); + __m512i _tmpf = _mm512_unpackhi_epi32(_sume, _sumd); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum8 = _mm512_unpacklo_epi64(_tmp8, _tmpa); + _sum9 = _mm512_unpackhi_epi64(_tmp8, _tmpa); + _suma = _mm512_unpacklo_epi64(_tmpb, _tmp9); + _sumb = _mm512_unpackhi_epi64(_tmpb, _tmp9); + _sumc = _mm512_unpacklo_epi64(_tmpc, _tmpe); + _sumd = _mm512_unpackhi_epi64(_tmpc, _tmpe); + _sume = _mm512_unpacklo_epi64(_tmpf, _tmpd); + _sumf = _mm512_unpackhi_epi64(_tmpf, _tmpd); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_CBAD); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sumc, _sum8, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum1, _sumd, 
_MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum9, _sum5, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sumd, _sum9, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp8 = _mm512_shuffle_i32x4(_sum2, _sume, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp9 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpa = _mm512_shuffle_i32x4(_suma, _sum6, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpb = _mm512_shuffle_i32x4(_sume, _suma, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmpc = _mm512_shuffle_i32x4(_sum3, _sumf, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmpd = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpe = _mm512_shuffle_i32x4(_sumb, _sum7, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpf = _mm512_shuffle_i32x4(_sumf, _sumb, _MM_SHUFFLE(1, 3, 1, 3)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp8, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmpc, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp9, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmpd, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _sum8 = _mm512_shuffle_i32x4(_tmp2, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum9 = _mm512_shuffle_i32x4(_tmp6, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _suma = _mm512_shuffle_i32x4(_tmpa, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _sumb = _mm512_shuffle_i32x4(_tmpe, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _sumc = _mm512_shuffle_i32x4(_tmp3, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sumd = _mm512_shuffle_i32x4(_tmp7, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sume = _mm512_shuffle_i32x4(_tmpb, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _sumf = _mm512_shuffle_i32x4(_tmpf, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + _mm512_store_si512((__m512i*)(outptr + 128), _sum8); + _mm512_store_si512((__m512i*)(outptr + 128 + 16), _sum9); + _mm512_store_si512((__m512i*)(outptr + 128 + 32), _suma); + _mm512_store_si512((__m512i*)(outptr + 128 + 48), _sumb); + _mm512_store_si512((__m512i*)(outptr + 128 + 64), _sumc); + _mm512_store_si512((__m512i*)(outptr + 128 + 80), _sumd); + _mm512_store_si512((__m512i*)(outptr + 128 + 96), _sume); + _mm512_store_si512((__m512i*)(outptr + 128 + 112), _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = 
_mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 16 * 4)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 16 * 5)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 16 * 6)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 16 * 7)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); +#endif + + pA += 32; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m256i _pB = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 16; + pB += 8; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 01 12 23 30 45 56 67 74 81 92 a3 b0 c5 d6 e7 f4 + // 02 13 20 31 46 57 64 75 82 93 a0 b1 c6 d7 e4 f5 + // 03 10 21 32 47 54 65 76 83 90 a1 b2 c7 d4 e5 f6 + // 40 51 62 73 04 15 26 37 c0 d1 e2 f3 84 95 a6 b7 + // 41 52 63 70 05 16 27 34 c1 d2 e3 f0 85 96 a7 b4 + // 42 53 60 71 06 17 
24 35 c2 d3 e0 f1 86 97 a4 b5 + // 43 50 61 72 07 14 25 36 c3 d0 e1 f2 87 94 a5 b6 + // to + // 00 10 20 30 44 54 64 74 80 90 a0 b0 c4 d4 e4 f4 + // 01 11 21 31 45 55 65 75 81 91 a1 b1 c5 d5 e5 f5 + // 02 12 22 32 46 56 66 76 82 92 a2 b2 c6 d6 e6 f6 + // 03 13 23 33 47 57 67 77 83 93 a3 b3 c7 d7 e7 f7 + // 40 50 60 70 04 14 24 34 c0 d0 e0 f0 84 94 a4 b4 + // 41 51 61 71 05 15 25 35 c1 d1 e1 f1 85 95 a5 b5 + // 42 52 62 72 06 16 26 36 c2 d2 e2 f2 86 96 a6 b6 + // 43 53 63 73 07 17 27 37 c3 d3 e3 f3 87 97 a7 b7 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + // TODO + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 16 * 3), _sum3); + _mm512_store_si512((__m512i*)(outptr + 16 * 4), _sum4); + _mm512_store_si512((__m512i*)(outptr + 16 * 5), _sum5); + _mm512_store_si512((__m512i*)(outptr + 16 * 6), _sum6); + _mm512_store_si512((__m512i*)(outptr + 16 * 7), _sum7); + outptr 
+= 16 * 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_castsi128_si512(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pB0 = _mm512_shuffle_i32x4(_pB, _pB, _MM_SHUFFLE(0, 0, 0, 0)); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 32; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB))); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 16; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 80 91 a2 b3 c0 d1 e2 f3 + // 01 12 23 30 41 52 63 70 81 92 a3 b0 c1 d2 e3 f0 + // 20 31 02 13 60 71 42 53 a0 b1 82 93 e0 f1 c2 d3 + // 21 32 03 10 61 72 43 50 a1 b2 83 90 e1 f2 c3 d0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 
16 * 3), _sum3); + outptr += 16 * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA, _pB1)); +#endif + + pA += 32; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_set1_epi32(((const int*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ABAB); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 16; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1 + // 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + { + __m512i _tmp0 = _mm512_shuffle_epi32(_sum0, _MM_PERM_DBCA); + __m512i _tmp1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_ACDB); + _sum0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + outptr += 16 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB)); +#endif + + pA += 32; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB = _mm512_set1_epi32(pB[0]); + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 16; + pB += 1; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + outptr += 16; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = 
_mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_loadu_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_loadu_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_loadu_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_loadu_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA00, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA00, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA11, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA11, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA11, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA00, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA00, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA11, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA11, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); +#endif // __AVX512VNNI__ + + pA += 16; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA00, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA00, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA00, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA00, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA11, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA11, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA11, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA11, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 8; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 
6e 7f + // 01 12 23 30 45 56 67 74 09 1a 2b 38 4d 5e 6f 7c + // 02 13 20 31 46 57 64 75 0a 1b 28 39 4e 5f 6c 7d + // 03 10 21 32 47 54 65 76 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 04 15 26 37 48 59 6a 7b 0c 1d 2e 3f + // 41 52 63 70 05 16 27 34 49 5a 6b 78 0d 1e 2f 3c + // 42 53 60 71 06 17 24 35 4a 5b 68 79 0e 1f 2c 3d + // 43 50 61 72 07 14 25 36 4b 58 69 7a 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 04 14 24 34 48 58 68 78 0c 1c 2c 3c + // 41 51 61 71 05 15 25 35 49 59 69 79 0d 1d 2d 3d + // 42 52 62 72 06 16 26 36 4a 5a 6a 7a 0e 1e 2e 3e + // 43 53 63 73 07 17 27 37 4b 5b 6b 7b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + 
_mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + _mm512_storeu_si512((__m512i*)(outptr + 64), _sum4); + _mm512_storeu_si512((__m512i*)(outptr + 80), _sum5); + _mm512_storeu_si512((__m512i*)(outptr + 96), _sum6); + _mm512_storeu_si512((__m512i*)(outptr + 112), _sum7); + outptr += 128; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX512F__ + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; +#else + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + __m256i _sum4; + __m256i _sum5; + __m256i _sum6; + __m256i _sum7; +#endif // __AVX512F__ + + if (k == 0) + { +#if __AVX512F__ + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); +#else + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + _sum4 = _mm256_setzero_si256(); + _sum5 = _mm256_setzero_si256(); + _sum6 = _mm256_setzero_si256(); + _sum7 = _mm256_setzero_si256(); +#endif // __AVX512F__ + } + else + { +#if __AVX512F__ + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); +#else + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + _sum4 = _mm256_load_si256((const __m256i*)(outptr + 32)); + _sum5 = _mm256_load_si256((const __m256i*)(outptr + 40)); + _sum6 = _mm256_load_si256((const __m256i*)(outptr + 48)); + _sum7 = _mm256_load_si256((const __m256i*)(outptr + 56)); +#endif // __AVX512F__ + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23)); +#endif // __AVX512VNNI__ +#else // __AVX512F__ + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + +#if __AVXVNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm256_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = 
_mm256_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm256_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3)); + _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA1, _pB0)); + _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA1, _pB1)); + _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); + _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); +#endif // __AVXVNNI__ +#endif // __AVX512F__ + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_shuffle_i32x4(_pA00, _pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_permutex_epi64(_pB01, _MM_SHUFFLE(2, 3, 0, 1)); + + __m512i _s01 = _mm512_mullo_epi32(_pA00, _pB01); + __m512i _s23 = _mm512_mullo_epi32(_pA00, _pB23); + __m512i _s45 = _mm512_mullo_epi32(_pA11, _pB01); + __m512i _s67 = _mm512_mullo_epi32(_pA11, _pB23); + _sum0 = _mm512_add_epi32(_sum0, _s01); + _sum1 = _mm512_add_epi32(_sum1, _s23); + _sum2 = _mm512_add_epi32(_sum2, _s45); + _sum3 = _mm512_add_epi32(_sum3, _s67); +#else + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA0, _pB2); + __m256i _s3 = _mm256_mullo_epi32(_pA0, _pB3); + __m256i _s4 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s5 = _mm256_mullo_epi32(_pA1, _pB1); + __m256i _s6 = _mm256_mullo_epi32(_pA1, _pB2); + __m256i _s7 = _mm256_mullo_epi32(_pA1, _pB3); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + _sum4 = _mm256_add_epi32(_sum4, _s4); + _sum5 = _mm256_add_epi32(_sum5, _s5); + _sum6 = _mm256_add_epi32(_sum6, _s6); + _sum7 = _mm256_add_epi32(_sum7, _s7); +#endif // __AVX512F__ + + pA += 8; + pB += 8; + } + +#if __AVX512F__ + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 04 14 24 34 40 50 60 70 + // 01 11 21 31 45 55 65 75 05 15 25 35 41 51 61 71 + // 02 12 22 32 46 56 66 76 06 16 26 36 42 52 62 72 + // 03 13 23 33 47 57 67 77 07 17 27 37 43 53 63 73 + { + __m512i _s0 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s1 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(2, 3, 3, 2)); + __m512i _s2 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s3 = _mm512_shuffle_i32x4(_sum0, 
_sum2, _MM_SHUFFLE(2, 3, 3, 2)); + _s1 = _mm512_shuffle_epi32(_s1, _MM_PERM_ADCB); + _s2 = _mm512_shuffle_epi32(_s2, _MM_PERM_BADC); + _s3 = _mm512_shuffle_epi32(_s3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_s0, _s1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_s0, _s1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_s2, _s3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_s2, _s3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 2, 1, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 2, 1, 2)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; +#else + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 + // 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 + // 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 + // 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 + // 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 + // 01 11 21 31 45 55 65 75 + // 02 12 22 32 46 56 66 76 + // 03 13 23 33 47 57 67 77 + // 40 50 60 70 04 14 24 34 + // 41 51 61 71 05 15 25 35 + // 42 52 62 72 06 16 26 36 + // 43 53 63 73 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(0, 3, 2, 1)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm256_shuffle_epi32(_sum6, _MM_SHUFFLE(1, 0, 3, 2)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + __m256i _tmp4 = _mm256_unpacklo_epi32(_sum4, _sum7); + __m256i _tmp5 = _mm256_unpackhi_epi32(_sum4, _sum7); + __m256i _tmp6 = _mm256_unpacklo_epi32(_sum6, _sum5); + __m256i _tmp7 = _mm256_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm256_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm256_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm256_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm256_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + + __m256i _tmp0 = _mm256_permute2x128_si256(_sum0, _sum4, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp1 = _mm256_permute2x128_si256(_sum1, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp2 = _mm256_permute2x128_si256(_sum2, _sum6, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp3 = 
_mm256_permute2x128_si256(_sum3, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp4 = _mm256_permute2x128_si256(_sum4, _sum0, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp5 = _mm256_permute2x128_si256(_sum5, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp6 = _mm256_permute2x128_si256(_sum6, _sum2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp7 = _mm256_permute2x128_si256(_sum7, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + _sum4 = _tmp4; + _sum5 = _tmp5; + _sum6 = _tmp6; + _sum7 = _tmp7; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 8 * 2), _sum2); + _mm256_store_si256((__m256i*)(outptr + 8 * 3), _sum3); + _mm256_store_si256((__m256i*)(outptr + 8 * 4), _sum4); + _mm256_store_si256((__m256i*)(outptr + 8 * 5), _sum5); + _mm256_store_si256((__m256i*)(outptr + 8 * 6), _sum6); + _mm256_store_si256((__m256i*)(outptr + 8 * 7), _sum7); + outptr += 8 * 8; +#endif // __AVX512F__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m256i _pA0 = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castpd_si128(_mm_load1_pd((const double*)pB))); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 + // 01 12 23 30 41 52 63 70 + // 20 31 02 13 60 71 42 53 + // 21 32 03 10 61 72 43 50 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + { + 
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA, _pB1)); +#endif + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castps_si128(_mm_load1_ps((const float*)pB))); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 + // 01 10 21 30 41 50 61 70 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + { + __m256i _tmp0 = _mm256_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _tmp1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m256i _sum0; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB)); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB = _mm256_set1_epi32(pB[0]); + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB); + _sum0 = 
_mm256_add_epi32(_sum0, _s0); + + pA += 8; + pB += 1; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + outptr += 8; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m256i _pAA = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m512i _pA0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pAA), _pAA, 1); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 4; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 3a 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = 
_mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + __m128i _sum4; + __m128i _sum5; + __m128i _sum6; + __m128i _sum7; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + _sum4 = _mm_setzero_si128(); + _sum5 = _mm_setzero_si128(); + _sum6 = _mm_setzero_si128(); + _sum7 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_loadu_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_loadu_si256((const __m256i*)(outptr + 24)); +#else + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + _sum4 = _mm_load_si128((const __m128i*)(outptr + 16)); + _sum5 = _mm_load_si128((const __m128i*)(outptr + 20)); + _sum6 = _mm_load_si128((const __m128i*)(outptr + 24)); + _sum7 = _mm_load_si128((const __m128i*)(outptr + 28)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif +#else // __AVX2__ + __m128i _pA0 = _mm_loadu_si128((const 
__m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maddd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maddd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maddd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA0, _pB2)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA0, _pB3)); + _sum4 = _mm_add_epi32(_sum4, _mm_madd_epi16(_pA1, _pB0)); + _sum5 = _mm_add_epi32(_sum5, _mm_madd_epi16(_pA1, _pB1)); + _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); + _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); +#endif +#endif // __AVX2__ + + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); +#else // __AVX2__ +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_unpackhi_epi16(_pB, _pB); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maccd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maccd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maccd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _pB; + __m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA0, _pB23); + __m128i _sh1 = _mm_mulhi_epi16(_pA0, _pB23); + __m128i _sl2 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh2 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _sl3 = _mm_mullo_epi16(_pA1, _pB23); + __m128i _sh3 = _mm_mulhi_epi16(_pA1, _pB23); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, 
_sh1); + __m128i _s4 = _mm_unpacklo_epi16(_sl2, _sh2); + __m128i _s5 = _mm_unpackhi_epi16(_sl2, _sh2); + __m128i _s6 = _mm_unpacklo_epi16(_sl3, _sh3); + __m128i _s7 = _mm_unpackhi_epi16(_sl3, _sh3); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); + _sum4 = _mm_add_epi32(_sum4, _s4); + _sum5 = _mm_add_epi32(_sum5, _s5); + _sum6 = _mm_add_epi32(_sum6, _s6); + _sum7 = _mm_add_epi32(_sum7, _s7); +#endif +#endif // __AVX2__ + + pA += 4; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + _tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + _mm256_storeu_si256((__m256i*)(outptr + 16), _sum2); + _mm256_storeu_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; +#else // __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm_shuffle_epi32(_sum6, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum6); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum6); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum7); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum7); + __m128i _tmp4 = _mm_unpacklo_epi32(_sum4, _sum2); + __m128i _tmp5 = _mm_unpackhi_epi32(_sum4, _sum2); + __m128i _tmp6 = _mm_unpacklo_epi32(_sum5, _sum3); + __m128i _tmp7 = _mm_unpackhi_epi32(_sum5, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp4); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp4); + _sum2 = _mm_unpacklo_epi64(_tmp5, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp5, _tmp1); + _sum4 = _mm_unpacklo_epi64(_tmp2, _tmp6); + _sum5 = _mm_unpackhi_epi64(_tmp2, _tmp6); + _sum6 = _mm_unpacklo_epi64(_tmp7, _tmp3); + _sum7 = _mm_unpackhi_epi64(_tmp7, _tmp3); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = 
_mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + _mm_store_si128((__m128i*)(outptr + 16), _sum4); + _mm_store_si128((__m128i*)(outptr + 20), _sum5); + _mm_store_si128((__m128i*)(outptr + 24), _sum6); + _mm_store_si128((__m128i*)(outptr + 28), _sum7); + outptr += 32; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maccd_epi16(_pA1, _pB1, _sum3); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + pA += 4; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 + // 01 12 23 30 + // 20 31 02 13 + // 21 32 03 10 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + { + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = 
_mm_unpacklo_epi32(_sum0, _sum3); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum3); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum2, _sum1); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); +#endif + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm_maccd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA, _pB1, _sum1); +#else + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 1, 0, 1)); + __m128i _sl = _mm_mullo_epi16(_pA, _pB01); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + pA += 4; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 + // 01 10 21 30 + // to + // 00 10 20 30 + // 01 11 21 31 + { + __m128i _tmp0 = _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _tmp1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); +#endif + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(pB[0]); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + _sum0 = 
_mm_maccd_epi16(_pA, _pB, _sum0); +#else + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); +#endif + + pA += 4; + pB += 1; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + outptr += 4; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pA1 = _mm512_set1_epi32(((const int*)pA)[1]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA1, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA1, _pB0)); +#endif // __AVX512VNNI__ + + pA += 4; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_set1_epi32(pA[0]); + __m512i _pA1 = _mm512_set1_epi32(pA[1]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA1, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 2; + pB += 16; + } + + if (k_end) + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_sum0, _sum0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_sum1, _sum1, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + outptr += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); +#else + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_loadu_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_loadu_si128((const __m128i*)(outptr + 12)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m256i _pA0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pA1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(pA + 2))); + __m256i _pB0 = _mm256_loadu_si256((const 
__m256i*)pB); + + // vs2019 internal compiler error with avx512 vnni intrinsics here + // fallback to avx2 madd anyway as a workaround --- nihui + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA1, _pB0)); +#else // __AVX2__ + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif // __AVX2__ + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pB = _mm_load_si128((const __m128i*)pB); +#if __AVX2__ + __m256i _pA0 = _mm256_set1_epi32(pA[0]); + __m256i _pA1 = _mm256_set1_epi32(pA[1]); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA1, _pB0); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); +#else // __AVX2__ + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif // __AVX2__ + pA += 2; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; +#else // __AVX2__ + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum2); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum2); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum3); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum3); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + _mm_storeu_si128((__m128i*)(outptr + 8), _sum2); + _mm_storeu_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB)); 
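Editor's note on the two accumulation paths used throughout this kernel: the VNNI branches fuse the pairwise int16 multiply-add and the 32-bit accumulation into one instruction (_mm512_dpwssd_epi32 / _mm256_dpwssd_epi32), while the fallback, also taken here as the VS2019 workaround noted in the comment, uses _mm*_madd_epi16 followed by a separate 32-bit add. A minimal scalar sketch of what one 32-bit output lane accumulates in either case, as an illustrative reference rather than part of the patch:

    #include <stdint.h>

    // Per 32-bit output lane: acc += a[2i] * b[2i] + a[2i+1] * b[2i+1],
    // with each int16 product widened to int32 before summing.
    static void dpwssd_scalar_ref(int32_t* acc, const int16_t* a, const int16_t* b, int lanes)
    {
        for (int i = 0; i < lanes; i++)
        {
            acc[i] += (int32_t)a[2 * i] * b[2 * i] + (int32_t)a[2 * i + 1] * b[2 * i + 1];
        }
    }

The two hardware paths agree except for the single pmaddwd overflow corner (all four int16 inputs equal to -32768), which cannot arise here because the Winograd-transformed int8 data stays far below the int16 limits.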
+ _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA1, _pB)); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpacklo_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 2; + pB += 4; + } + + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum1); + _sum0 = _tmp0; + _sum1 = _tmp1; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 2 * 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum00 = 0; + int sum01 = 0; + int sum10 = 0; + int sum11 = 0; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum00 += pA[0] * pB[0]; + sum00 += pA[1] * pB[1]; + sum01 += pA[2] * pB[0]; + sum01 += pA[3] * pB[1]; + sum10 += pA[0] * pB[2]; + sum10 += pA[1] * pB[3]; + sum11 += pA[2] * pB[2]; + sum11 += pA[3] * pB[3]; + + pA += 4; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + } + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); +#endif + + pA += 2; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_set1_epi32(pA[0]); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_loadu_si256((const __m256i*)pB)); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 1; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + outptr += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; 
jj += 8) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 1; + pB += 8; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + pA += 1; + pB += 4; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum = 0; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + outptr += 1; + } + } + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size_int8 = (int)(get_cpu_level2_cache_size() / sizeof(short)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // solve M + { + int tile_size = (int)sqrt((float)l2_cache_size_int8 / 3); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_M = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 
2); +#endif + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + } + + // solve K + { + int tile_size = (int)(sqrt((float)l2_cache_size_int8) - TILE_M); + +#if __AVX512F__ + TILE_K = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + if (N > 0) + { + int tile_size = (int)((l2_cache_size_int8 - TILE_M * TILE_K) / (TILE_M * 2 + TILE_K)); + +#if __SSE2__ + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + } +} + +static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const signed char ktm[4][3] = { + // {2, 0, 0}, + // {1, 1, 1}, + // {1, -1, 1}, + // {0, 0, 2} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[4][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 2; + tmp[1][m] = r0 + r1 + r2; + tmp[2][m] = r0 - r1 + r2; + tmp[3][m] = r2 * 2; + + k0 += 3; + } + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 2; + short z1 = r0 + r1 + r2; + short z2 = r0 - r1 + r2; + short z3 = r2 * 2; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp += 4; + } + } + } +} + +static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 16; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M 
- 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd23_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd23_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const signed char itm[4][4] = { + // {1, 0, -1, 0}, + // {0, 1, 1, 0}, + // {0, -1, 1, 0}, + // {0, -1, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w - 1) / 2; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[4][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + + if (ti * 2 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 2 + 2 < w) _r2 = 
_mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 2), sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + } + } + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + + p0 += max_jj * 4 * 16; + p1 += max_jj * 4 * 16; + p2 += max_jj * 4 * 16; + p3 += max_jj * 4 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[4][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + + if (ti * 2 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 2 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 2 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 2 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 2 + 2 
< w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 2 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 2 + 2 < w) + { + __m256i _val2_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 2 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 2 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 2 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 2 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); +#endif // __AVX2__ + } + } + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + // old gcc breaks stack variable alignement + // ref https://gcc.gnu.org/bugzilla/show_bug.cgi?id=16660 + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], 
_tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[4][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 2 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 2 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 2 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + } + } + + tmp[0][m][0] = r00 - r20; + tmp[0][m][1] = r01 - r21; + tmp[1][m][0] = r10 + r20; + tmp[1][m][1] = r11 + r21; + tmp[2][m][0] = r20 - r10; + tmp[2][m][1] = r21 - r11; + tmp[3][m][0] = r30 - r10; + tmp[3][m][1] = r31 - r11; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short 
tmp[4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 2 + 1 < w) r1 = r0123[1]; + if (tj * 2 + 2 < w) r2 = r0123[2]; + if (tj * 2 + 3 < w) r3 = r0123[3]; + } + } + + tmp[0][m] = r0 - r2; + tmp[1][m] = r1 + r2; + tmp[2][m] = r2 - r1; + tmp[3][m] = r3 - r1; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; + } + } + } +} + +static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[2][4] = { + // {1, 1, 1, 0}, + // {0, 1, -1, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 1) / 2; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[2][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 16; + r1 += max_jj * 4 * 16; + r2 += max_jj * 4 * 16; + r3 += max_jj * 4 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm512_srai_epi32(_tmp0, 2); + _tmp1 = _mm512_srai_epi32(_tmp1, 2); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + } + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, 
_mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[2][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); +#endif + + __m256i _tmp0 = 
_mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm256_srai_epi32(_tmp0, 2); + _tmp1 = _mm256_srai_epi32(_tmp1, 2); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)outptr0, _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[2][4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = 
_mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm_srai_epi32(_tmp0, 2); + _tmp1 = _mm_srai_epi32(_tmp1, 2); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[2][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp00 = tmp[m][0][0] + tmp[m][1][0] + tmp[m][2][0]; + int tmp01 = tmp[m][0][1] + tmp[m][1][1] + tmp[m][2][1]; + int tmp10 = tmp[m][1][0] - tmp[m][2][0] + tmp[m][3][0]; + int tmp11 = tmp[m][1][1] - tmp[m][2][1] + tmp[m][3][1]; + + tmp00 = tmp00 >> 2; + tmp01 = tmp01 >> 2; + tmp10 = tmp10 >> 2; + tmp11 = tmp11 >> 2; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[2][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; + + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + 
continue; + + int tmp0 = tmp[m][0] + tmp[m][1] + tmp[m][2]; + int tmp1 = tmp[m][1] - tmp[m][2] + tmp[m][3]; + + tmp0 = tmp0 >> 2; + tmp1 = tmp1 >> 2; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 2 + 1 < outw) outptr0[1] = tmp1; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd23_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd23_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd23_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 2n+2, winograd F(2,3) + int w_tiles = (outw + 1) / 2; + int h_tiles = (outh + 1) / 2; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 16; + + // NCNN_LOGE("conv3x3s1_winograd23_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 
4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // transform output + conv3x3s1_winograd23_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} + +static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const short ktm[6][3] = { + // {6, 0, 0}, + // {-4, -4, -4}, + // {-4, 4, -4}, + // {1, 2, 4}, + // {1, -2, 4}, + // {0, 0, 6} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[6][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 6; + tmp[1][m] = -r0 * 4 - r1 * 4 - r2 * 4; + tmp[2][m] = -r0 * 4 + r1 * 4 - r2 * 4; + tmp[3][m] = r0 + r1 * 2 + r2 * 4; + tmp[4][m] = r0 - r1 * 2 + r2 * 4; + tmp[5][m] = r2 * 6; + + k0 += 3; + } + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 6; + short z1 = -r0 * 4 - r1 * 4 - r2 * 4; + short z2 = -r0 * 4 + r1 * 4 - r2 * 4; + short z3 = r0 + r1 * 2 + r2 * 4; + short z4 = r0 - r1 * 2 + r2 * 4; + short z5 = r2 * 6; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp[4] = z4; + ptmp[5] = z5; + ptmp += 6; + } + } + } +} + +static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 36; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 4u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd43_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd43_transform_input_tile_int8(const 
Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const float itm[4][4] = { + // {4, 0, -5, 0, 1, 0}, + // {0, -4, -4, 1, 1, 0}, + // {0, 4, -4, -1, 1, 0}, + // {0, -2, -1, 2, 1, 0}, + // {0, 2, -1, -2, 1, 0}, + // {0, 4, 0, -5, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w + 1) / 4; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[6][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m256i _v2 = _mm256_set1_epi16(2); + __m256i _v4 = _mm256_set1_epi16(4); + __m256i _v5 = _mm256_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + __m256i _r4 = _mm256_setzero_si256(); + __m256i _r5 = _mm256_setzero_si256(); + + if (ti * 4 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 64))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 80))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 32)), _mm_loadl_epi64((const __m128i*)(r1 + 32)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 40)), _mm_loadl_epi64((const __m128i*)(r1 + 40)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 
2), sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 4), sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 5), sizeof(signed char)))); + } + } + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + _mm256_store_si256((__m256i*)tmp[4][m], _tmp4); + _mm256_store_si256((__m256i*)tmp[5][m], _tmp5); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + short* p4 = p0 + max_jj * 16 * 4; + short* p5 = p0 + max_jj * 16 * 5; + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + _mm256_store_si256((__m256i*)p4, _tmp4); + _mm256_store_si256((__m256i*)p5, _tmp5); + + p0 += max_jj * 6 * 16; + p1 += max_jj * 6 * 16; + p2 += max_jj * 6 * 16; + p3 += max_jj * 6 * 16; + p4 += max_jj * 6 * 16; + p5 += max_jj * 6 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + 
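        // A sketch of the row transform that the 8-lane and scalar paths below
        // implement, written out from the B^T matrix quoted at the top of this
        // function (a reading aid only, hedged inference rather than additional logic):
        //   t0 =  4*d0        - 5*d2        +   d4
        //   t1 =       -4*d1  - 4*d2 +   d3 +   d4
        //   t2 =        4*d1  - 4*d2 -   d3 +   d4
        //   t3 =       -2*d1  -   d2 + 2*d3 +   d4
        //   t4 =        2*d1  -   d2 - 2*d3 +   d4
        //   t5 =        4*d1         - 5*d3        + d5
        // The same operation is applied first along rows and then along columns
        // (BT * d * B). With int8 inputs (|d| <= 127) and a per-row absolute
        // coefficient sum of at most 4 + 5 + 1 = 10, the two passes are bounded by
        // 127 * 10 * 10 = 12700 < 32767, which appears to be why 16-bit lanes
        // (epi16 / short tmp[6][6][...]) are sufficient throughout this function.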
for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[6][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m128i _v2 = _mm_set1_epi16(2); + __m128i _v4 = _mm_set1_epi16(4); + __m128i _v5 = _mm_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + __m128i _r4 = _mm_setzero_si128(); + __m128i _r5 = _mm_setzero_si128(); + + if (ti * 4 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 4 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 4 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 4 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + if (tj * 4 + 4 < w) + { + _r4 = _mm_loadl_epi64((const __m128i*)(r0 + 32)); + _r4 = _mm_unpacklo_epi8(_r4, _mm_cmpgt_epi8(_mm_setzero_si128(), _r4)); + } + if (tj * 4 + 5 < w) + { + _r5 = _mm_loadl_epi64((const __m128i*)(r0 + 40)); + _r5 = _mm_unpacklo_epi8(_r5, _mm_cmpgt_epi8(_mm_setzero_si128(), _r5)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 4 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 4 + 2 < w) + { + __m256i _val2_32 = 
_mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 4 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } + if (tj * 4 + 4 < w) + { + __m256i _val4_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)), _sindex88); + _r4 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val4_32, 0), _mm256_extracti128_si256(_val4_32, 1))); + } + if (tj * 4 + 5 < w) + { + __m256i _val5_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)), _sindex88); + _r5 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val5_32, 0), _mm256_extracti128_si256(_val5_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 4 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 4 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 4 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); + if (tj * 4 + 4 < w) _r4 = _mm_setr_epi16(r0[4], r1[4], r2[4], r3[4], r4[4], r5[4], r6[4], r7[4]); + if (tj * 4 + 5 < w) _r5 = _mm_setr_epi16(r0[5], r1[5], r2[5], r3[5], r4[5], r5[5], r6[5], r7[5]); +#endif // __AVX2__ + } + } + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + 
_mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); + _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4); + _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); + _mm_store_si128((__m128i*)tmp[4][m], _tmp4); + _mm_store_si128((__m128i*)tmp[5][m], _tmp5); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + short* p4 = p0 + max_jj * 8 * 4; + short* p5 = p0 + max_jj * 8 * 5; + + for (int m = 0; m < 6; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + _mm_store_si128((__m128i*)p4, _tmp4); + _mm_store_si128((__m128i*)p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[6][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + signed char r40 = 0; + signed char r41 = 0; + signed char r50 = 0; + signed char r51 = 0; + + if (ti * 4 + m 
< h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 4 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 4 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 4 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + if (tj * 4 + 4 < w) + { + r40 = r0[4]; + r41 = r1[4]; + } + if (tj * 4 + 5 < w) + { + r50 = r0[5]; + r51 = r1[5]; + } + } + } + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + tmp[0][m][0] = r40 + r00 * 4 - r20 * 5; + tmp[0][m][1] = r41 + r01 * 4 - r21 * 5; + tmp[1][m][0] = tmp120b + tmp120a; + tmp[1][m][1] = tmp121b + tmp121a; + tmp[2][m][0] = tmp120b - tmp120a; + tmp[2][m][1] = tmp121b - tmp121a; + tmp[3][m][0] = tmp340b + tmp340a; + tmp[3][m][1] = tmp341b + tmp341a; + tmp[4][m][0] = tmp340b - tmp340a; + tmp[4][m][1] = tmp341b - tmp341a; + tmp[5][m][0] = r50 + r10 * 4 - r30 * 5; + tmp[5][m][1] = r51 + r11 * 4 - r31 * 5; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + short* p4 = p0 + max_jj * 2 * 4; + short* p5 = p0 + max_jj * 2 * 5; + + for (int m = 0; m < 6; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + short r40 = tmp[m][4][0]; + short r41 = tmp[m][4][1]; + short r50 = tmp[m][5][0]; + short r51 = tmp[m][5][1]; + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + p0[0] = r40 + r00 * 4 - r20 * 5; + p0[1] = r41 + r01 * 4 - r21 * 5; + p1[0] = tmp120b + tmp120a; + p1[1] = tmp121b + tmp121a; + p2[0] = tmp120b - tmp120a; + p2[1] = tmp121b - tmp121a; + p3[0] = tmp340b + tmp340a; + p3[1] = tmp341b + tmp341a; + p4[0] = tmp340b - tmp340a; + p4[1] = tmp341b - tmp341a; + p5[0] = r50 + r10 * 4 - r30 * 5; + p5[1] = r51 + r11 * 4 - r31 * 5; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[6][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + signed char r4 = 0; + signed char r5 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 4 + 1 < w) r1 = r0123[1]; + if (tj * 4 + 2 < w) r2 = r0123[2]; + if (tj * 4 + 3 < w) r3 = r0123[3]; + if (tj * 4 + 4 < w) r4 = r0123[4]; + if (tj * 4 + 5 < w) r5 = r0123[5]; + } + } + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + tmp[0][m] = r4 + r0 * 4 - r2 * 5; + tmp[1][m] = tmp12b + tmp12a; + tmp[2][m] = tmp12b - tmp12a; + tmp[3][m] = tmp34b + tmp34a; + tmp[4][m] = 
tmp34b - tmp34a; + tmp[5][m] = r5 + r1 * 4 - r3 * 5; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + short* p4 = p0 + max_jj * 4; + short* p5 = p0 + max_jj * 5; + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + short r4 = tmp[m][4]; + short r5 = tmp[m][5]; + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + p0[0] = r4 + r0 * 4 - r2 * 5; + p1[0] = tmp12b + tmp12a; + p2[0] = tmp12b - tmp12a; + p3[0] = tmp34b + tmp34a; + p4[0] = tmp34b - tmp34a; + p5[0] = r5 + r1 * 4 - r3 * 5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; + } + } + } +} + +static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[4][6] = { + // {1, 1, 1, 1, 1, 0}, + // {0, 1, -1, 2, -2, 0}, + // {0, 1, 1, 4, 4, 0}, + // {0, 1, -1, 8, -8, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 3) / 4; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[4][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + const int* r4 = r0 + max_jj * 16 * 4; + const int* r5 = r0 + max_jj * 16 * 5; + + for (int m = 0; m < 5; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = _mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + for (int m = 5; m < 6; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = 
_mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _tmp0 = _mm512_slli_epi32(_tmp0, 2); + _tmp1 = _mm512_slli_epi32(_tmp1, 2); + _tmp2 = _mm512_slli_epi32(_tmp2, 2); + _tmp3 = _mm512_slli_epi32(_tmp3, 2); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + __m512i _r4 = _mm512_load_si512((const __m512i*)tmp[m][4]); + __m512i _r5 = _mm512_load_si512((const __m512i*)tmp[m][5]); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m512 _v576 = _mm512_set1_ps(1.0 / 576); + _tmp0 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + if (tj * 4 + 2 < outw) _mm512_store_si512((__m512i*)(outptr0 + 32), _tmp2); + if (tj * 4 + 3 < outw) _mm512_store_si512((__m512i*)(outptr0 + 48), _tmp3); + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, _mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 16), _mm512_extracti32x8_epi32(_tmp2, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 16), _mm512_extracti32x8_epi32(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { 
+ _mm256_store_si256((__m256i*)(outptr0 + 24), _mm512_extracti32x8_epi32(_tmp3, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 24), _mm512_extracti32x8_epi32(_tmp3, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm512_extracti32x4_epi32(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm512_extracti32x4_epi32(_tmp2, 1)); + _mm_store_si128((__m128i*)(outptr2 + 8), _mm512_extracti32x4_epi32(_tmp2, 2)); + _mm_store_si128((__m128i*)(outptr3 + 8), _mm512_extracti32x4_epi32(_tmp2, 3)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm512_extracti32x4_epi32(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm512_extracti32x4_epi32(_tmp3, 1)); + _mm_store_si128((__m128i*)(outptr2 + 12), _mm512_extracti32x4_epi32(_tmp3, 2)); + _mm_store_si128((__m128i*)(outptr3 + 12), _mm512_extracti32x4_epi32(_tmp3, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm512_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm512_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[4][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + const int* r4 = r0 + max_jj * 8 * 4; + const int* r5 = r0 + max_jj * 8 * 5; + + for (int m = 0; m < 5; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = 
_mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + for (int m = 5; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + + _tmp0 = _mm256_slli_epi32(_tmp0, 2); + _tmp1 = _mm256_slli_epi32(_tmp1, 2); + _tmp2 = _mm256_slli_epi32(_tmp2, 2); + _tmp3 = _mm256_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_loadu_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_loadu_si256((const __m256i*)tmp[m][5]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); +#endif + + __m256i _tmp02a = 
_mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m256 _v576 = _mm256_set1_ps(1.0 / 576); + _tmp0 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + if (tj * 4 + 2 < outw) _mm256_store_si256((__m256i*)(outptr0 + 16), _tmp2); + if (tj * 4 + 3 < outw) _mm256_store_si256((__m256i*)(outptr0 + 24), _tmp3); + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)(outptr0), _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)(outptr1), _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm256_extracti128_si256(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm256_extracti128_si256(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm256_extracti128_si256(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm256_extracti128_si256(_tmp3, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm256_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm256_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + int tmp2[8]; + int tmp3[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + _mm256_storeu_si256((__m256i*)tmp2, _tmp2); + _mm256_storeu_si256((__m256i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + 
outptr3[2] = tmp2[3]; + outptr4[2] = tmp2[4]; + outptr5[2] = tmp2[5]; + outptr6[2] = tmp2[6]; + outptr7[2] = tmp2[7]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + outptr4[3] = tmp3[4]; + outptr5[3] = tmp3[5]; + outptr6[3] = tmp3[6]; + outptr7[3] = tmp3[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[4][6][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + const int* r4 = r0 + max_jj * 4 * 4; + const int* r5 = r0 + max_jj * 4 * 5; + + for (int m = 0; m < 5; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + for (int m = 5; m < 6; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + + _tmp0 = _mm_slli_epi32(_tmp0, 2); + _tmp1 = _mm_slli_epi32(_tmp1, 2); + _tmp2 = _mm_slli_epi32(_tmp2, 2); + _tmp3 = _mm_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + 
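+                // unaligned stores in this branch: very old GCC (< 4.6) is presumably not trusted to
+                // honour the aligned(16) attribute on the stack array tmp, so the aligned
+                // _mm_store_si128 path below is only taken on newer compilers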
_mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m128 _v576 = _mm_set1_ps(1.0 / 576); + _tmp0 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + if (tj * 4 + 2 < outw) _mm_store_si128((__m128i*)(outptr0 + 8), _tmp2); + if (tj * 4 + 3 < outw) _mm_store_si128((__m128i*)(outptr0 + 12), _tmp3); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + int tmp2[4]; + int tmp3[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + _mm_storeu_si128((__m128i*)tmp2, _tmp2); + _mm_storeu_si128((__m128i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + if (tj 
* 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + outptr3[2] = tmp2[3]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[4][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + const int* r4 = r0 + max_jj * 2 * 4; + const int* r5 = r0 + max_jj * 2 * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + for (int m = 5; m < 6; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp00 = tmp00 * 4; + tmp01 = tmp01 * 4; + tmp10 = tmp10 * 4; + tmp11 = tmp11 * 4; + tmp20 = tmp20 * 4; + tmp21 = tmp21 * 4; + tmp30 = tmp30 * 4; + tmp31 = tmp31 * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a0 = tmp[m][1][0] + tmp[m][2][0]; + int tmp02a1 = tmp[m][1][1] + tmp[m][2][1]; + int tmp02b0 = tmp[m][3][0] + tmp[m][4][0]; + int tmp02b1 = tmp[m][3][1] + tmp[m][4][1]; + int tmp13a0 = tmp[m][1][0] - tmp[m][2][0]; + int tmp13a1 = tmp[m][1][1] - tmp[m][2][1]; + int tmp13b0 = tmp[m][3][0] - tmp[m][4][0]; + int tmp13b1 = tmp[m][3][1] - tmp[m][4][1]; + + int tmp00 = tmp02a0 + tmp02b0 + tmp[m][0][0]; + int tmp01 = 
tmp02a1 + tmp02b1 + tmp[m][0][1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + tmp[m][5][0]; + int tmp31 = tmp13a1 + tmp13b1 * 8 + tmp[m][5][1]; + + tmp00 = tmp00 / 576; + tmp01 = tmp01 / 576; + tmp10 = tmp10 / 576; + tmp11 = tmp11 / 576; + tmp20 = tmp20 / 576; + tmp21 = tmp21 / 576; + tmp30 = tmp30 / 576; + tmp31 = tmp31 / 576; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp20; + outptr1[2] = tmp21; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp30; + outptr1[3] = tmp31; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[4][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + const int* r4 = r0 + max_jj * 4; + const int* r5 = r0 + max_jj * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + for (int m = 5; m < 6; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp0 = tmp0 * 4; + tmp1 = tmp1 * 4; + tmp2 = tmp2 * 4; + tmp3 = tmp3 * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a = tmp[m][1] + tmp[m][2]; + int tmp02b = tmp[m][3] + tmp[m][4]; + int tmp13a = tmp[m][1] - tmp[m][2]; + int tmp13b = tmp[m][3] - tmp[m][4]; + + int tmp0 = tmp02a + tmp02b + tmp[m][0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + tmp[m][5]; + + tmp0 = tmp0 / 576; + tmp1 = tmp1 / 576; + tmp2 = tmp2 / 576; + tmp3 = tmp3 / 576; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 4 + 1 < outw) outptr0[1] = tmp1; + if (tj * 4 + 2 < outw) outptr0[2] = tmp2; + if (tj * 4 + 3 < outw) outptr0[3] = tmp3; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd43_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU 
&& NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd43_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd43_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 4n+2, winograd F(4,3) + int w_tiles = (outw + 3) / 4; + int h_tiles = (outh + 3) / 4; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 36; + + // NCNN_LOGE("conv3x3s1_winograd43_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // 
transform output + conv3x3s1_winograd43_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 56a4aa4763a..da504677a68 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -934,7 +934,6 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); } - // TODO __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); @@ -2547,12 +2546,12 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); - _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); - _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); - _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); - _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); _mm512_storeu_si512((__m512i*)outptr0, _sum0); _mm512_storeu_si512((__m512i*)(outptr0 + 16), _sum1); @@ -6142,14 +6141,13 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo } } -static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +template +#if __AVX512F__ +void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else // __AVX512F__ +void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif // __AVX512F__ { - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); - return; - } - const int w = bottom_blob.w; // const int channels = bottom_blob.c; const int elempack = bottom_blob.elempack; @@ -6206,288 +6204,468 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dxe = (j + jj + 14) % outw; int dxf = (j + jj + 15) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dyf) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int 
x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int x08 = stride_w * dx8 + dilation_w * v0; - int x09 = stride_w * dx9 + dilation_w * v0; - int x0a = stride_w * dxa + dilation_w * v0; - int x0b = stride_w * dxb + dilation_w * v0; - int x0c = stride_w * dxc + dilation_w * v0; - int x0d = stride_w * dxd + dilation_w * v0; - int x0e = stride_w * dxe + dilation_w * v0; - int x0f = stride_w * dxf + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - int y08 = stride_h * dy8 + dilation_h * u0; - int y09 = stride_h * dy9 + dilation_h * u0; - int y0a = stride_h * dya + dilation_h * u0; - int y0b = stride_h * dyb + dilation_h * u0; - int y0c = stride_h * dyc + dilation_h * u0; - int y0d = stride_h * dyd + dilation_h * u0; - int y0e = stride_h * dye + dilation_h * u0; - int y0f = stride_h * dyf + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int x18 = stride_w * dx8 + dilation_w * v1; - int x19 = stride_w * dx9 + dilation_w * v1; - int x1a = stride_w * dxa + dilation_w * v1; - int x1b = stride_w * dxb + dilation_w * v1; - int x1c = stride_w * dxc + dilation_w * v1; - int x1d = stride_w * dxd + dilation_w * v1; - int x1e = stride_w * dxe + dilation_w * v1; - int x1f = stride_w * dxf + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - int y18 = stride_h * dy8 + dilation_h * u1; - int y19 = stride_h * dy9 + dilation_h * u1; - int y1a = stride_h * dya + dilation_h * u1; - int y1b = stride_h * dyb + dilation_h * u1; - int y1c = stride_h * dyc + dilation_h * u1; - int y1d = stride_h * dyd + dilation_h * u1; - int y1e = stride_h * dye + dilation_h * u1; - int y1f = stride_h * dyf + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - const signed char* sptr08 = img0.row(y08) + x08; - const signed char* sptr09 = img0.row(y09) + x09; - const signed char* sptr0a = img0.row(y0a) + x0a; - const signed char* sptr0b = img0.row(y0b) + x0b; - const signed char* sptr0c = img0.row(y0c) + x0c; - const signed char* sptr0d = 
img0.row(y0d) + x0d; - const signed char* sptr0e = img0.row(y0e) + x0e; - const signed char* sptr0f = img0.row(y0f) + x0f; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - const signed char* sptr18 = img1.row(y18) + x18; - const signed char* sptr19 = img1.row(y19) + x19; - const signed char* sptr1a = img1.row(y1a) + x1a; - const signed char* sptr1b = img1.row(y1b) + x1b; - const signed char* sptr1c = img1.row(y1c) + x1c; - const signed char* sptr1d = img1.row(y1d) + x1d; - const signed char* sptr1e = img1.row(y1e) + x1e; - const signed char* sptr1f = img1.row(y1f) + x1f; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp[16 + 0] = sptr08[0]; - pp[16 + 1] = sptr18[0]; - pp[16 + 2] = sptr09[0]; - pp[16 + 3] = sptr19[0]; - pp[16 + 4] = sptr0a[0]; - pp[16 + 5] = sptr1a[0]; - pp[16 + 6] = sptr0b[0]; - pp[16 + 7] = sptr1b[0]; - pp[16 + 8] = sptr0c[0]; - pp[16 + 9] = sptr1c[0]; - pp[16 + 10] = sptr0d[0]; - pp[16 + 11] = sptr1d[0]; - pp[16 + 12] = sptr0e[0]; - pp[16 + 13] = sptr1e[0]; - pp[16 + 14] = sptr0f[0]; - pp[16 + 15] = sptr1f[0]; - pp += 32; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 
2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - int xc = stride_w * dxc + dilation_w * v; - int xd = stride_w * dxd + dilation_w * v; - int xe = stride_w * dxe + dilation_w * v; - int xf = stride_w * dxf + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - int yc = stride_h * dyc + dilation_h * u; - int yd = stride_h * dyd + dilation_h * u; - int ye = stride_h * dye + dilation_h * u; - int yf = stride_h * dyf + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - const signed char* sptr8 = img.row(y8) + x8 * elempack; - const signed char* sptr9 = img.row(y9) + x9 * elempack; - const signed char* sptra = img.row(ya) + xa * elempack; - const signed char* sptrb = img.row(yb) + xb * elempack; - const signed char* sptrc = img.row(yc) + xc * 
elempack; - const signed char* sptrd = img.row(yd) + xd * elempack; - const signed char* sptre = img.row(ye) + xe * elempack; - const signed char* sptrf = img.row(yf) + xf * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); - __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); - __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); - __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); - __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); - __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); - __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_storeu_si128((__m128i*)pp, _r8); - _mm_storeu_si128((__m128i*)(pp + 16), _r9); - _mm_storeu_si128((__m128i*)(pp + 32), _ra); - _mm_storeu_si128((__m128i*)(pp + 48), _rb); - _mm_storeu_si128((__m128i*)(pp + 64), _rc); - _mm_storeu_si128((__m128i*)(pp + 80), _rd); - _mm_storeu_si128((__m128i*)(pp + 96), _re); - _mm_storeu_si128((__m128i*)(pp + 112), _rf); - pp += 128; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const 
__m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp[12] = sptrc[0]; - pp[13] = sptrd[0]; - pp[14] = sptre[0]; - pp[15] = sptrf[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = 
stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = 
img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const 
signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } } } } @@ -6511,168 +6689,298 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx6 = (j + jj + 6) % outw; int dx7 = (j + jj + 7) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy7) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = 
stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = 
_mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 
= _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + 
dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const 
__m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } } } } @@ -6688,106 +6996,206 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx2 = (j + jj + 2) % outw; int dx3 = (j + jj + 3) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy3) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = 
img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr = img.row(y0) + x0 * elempack; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _mm_storeu_si128((__m128i*)pp, _r0); - _mm_storeu_si128((__m128i*)(pp + 16), _r1); - pp += 32; + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / 
maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } } } } @@ -6799,44 +7207,154 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx0 = (j + jj) % outw; int dx1 = (j + jj + 1) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy1) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; 
- int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % 
kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } + } + for (; jj < max_jj; jj++) + { + int dy = (j + jj) / outw; + int dx = (j + jj) % outw; + + int kk = 0; for (; kk < max_kk / elempack; kk++) { int p = (k / elempack + kk) / maxk; @@ -6846,29 +7364,1309 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; + int x = stride_w * dx + dilation_w * v; + int y = stride_h * dy + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr = img.row(y) + x * elempack; #if __SSE2__ if (elempack == 8) { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)sptr)); + pp += 8; } #endif // __SSE2__ if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; + pp[0] = sptr[0]; + pp += 1; + } + } + } +} + +#if __AVX512F__ +template void convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#else // __AVX512F__ +template void convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); 
+template void convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#endif // __AVX512F__ + +static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +{ + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); + return; + } + + if (kernel_w == 1 && kernel_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + const int w = bottom_blob.w; + // const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int outw = (w - kernel_extent_w) / stride_w + 1; + + // j max_jj outw*outh split w and h + + // k max_kk pa*maxk*(inch/pa) split inch + + // k/max_kk shall be multiple of maxk + + const int maxk = kernel_w * kernel_h; + + signed char* pp = B; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj 
+ 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dy8 = (j + jj + 8) / outw; + int dy9 = (j + jj + 9) / outw; + int dya = (j + jj + 10) / outw; + int dyb = (j + jj + 11) / outw; + int dyc = (j + jj + 12) / outw; + int dyd = (j + jj + 13) / outw; + int dye = (j + jj + 14) / outw; + int dyf = (j + jj + 15) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + int dx8 = (j + jj + 8) % outw; + int dx9 = (j + jj + 9) % outw; + int dxa = (j + jj + 10) % outw; + int dxb = (j + jj + 11) % outw; + int dxc = (j + jj + 12) % outw; + int dxd = (j + jj + 13) % outw; + int dxe = (j + jj + 14) % outw; + int dxf = (j + jj + 15) % outw; + + if (dy0 == dyf) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 
12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = 
sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + 
const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + 
int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); 
+ _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + + if (dy0 == dy7) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 
= stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed 
char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + 
_mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + + if (dy0 == dy3) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + 
kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + + if (dy0 == dy1) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 
/ kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } } diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index f870a884746..09008985f12 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -46,16 +46,13 @@ namespace ncnn { #include "convolution_packed_int8.h" #include 
"convolution_im2col_gemm_int8.h" + +#include "convolution_3x3_winograd_int8.h" #endif // NCNN_INT8 #if __SSE2__ #include "convolution_3x3_pack1to4.h" -#if NCNN_INT8 -#include "convolution_3x3_pack8to4_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#endif // NCNN_INT8 - #if __AVX__ #include "convolution_3x3_pack1to8.h" #include "convolution_3x3_pack8to1.h" @@ -1231,32 +1228,14 @@ int Convolution_x86::create_pipeline_int8_x86(const Option& opt) const int maxk = kernel_w * kernel_h; const int num_input = weight_data_size / maxk / num_output; - int elempack = 1; - int out_elempack_int32 = 1; -#if __SSE2__ - if (opt.use_packing_layout) - { - elempack = num_input % 8 == 0 ? 8 : 1; - out_elempack_int32 = num_output % 4 == 0 ? 4 : 1; - } -#endif // __SSE2__ + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_transform_kernel_int8_sse(weight_data, weight_winograd23_data, num_input, num_output, opt); - // conv3x3s1_winograd43_transform_kernel_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); + if (opt.use_winograd43_convolution) + conv3x3s1_winograd43_transform_kernel_int8(weight_data, weight_winograd43_data, num_input, num_output, opt); + else + conv3x3s1_winograd23_transform_kernel_int8(weight_data, weight_winograd23_data, num_input, num_output, opt); } else if (opt.use_sgemm_convolution) { @@ -1352,6 +1331,8 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (top_blob_int32.empty()) return -100; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); + int _nT = nT ? 
nT : opt.num_threads; if (nT != 0 && opt.num_threads != nT) { @@ -1360,22 +1341,12 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, opt); - // conv3x3s1_winograd43_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); + if (opt.use_winograd43_convolution && !weight_winograd43_data.empty()) + conv3x3s1_winograd43_int8(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, _nT, opt); + else + conv3x3s1_winograd23_int8(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, _nT, opt); } else if (opt.use_sgemm_convolution) { diff --git a/src/layer/x86/convolution_x86_avx2.cpp b/src/layer/x86/convolution_x86_avx2.cpp index 38f107ee086..49cded70213 100644 --- a/src/layer/x86/convolution_x86_avx2.cpp +++ b/src/layer/x86/convolution_x86_avx2.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_transform_kernel_packed_int8_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) @@ -46,24 +45,24 @@ void convolution_im2col_gemm_int8_avx2(const Mat& bottom_blob, Mat& top_blob, co } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void 
conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd43_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avx512vnni.cpp b/src/layer/x86/convolution_x86_avx512vnni.cpp index f0ac51bbf85..8e34bb61309 100644 --- a/src/layer/x86/convolution_x86_avx512vnni.cpp +++ b/src/layer/x86/convolution_x86_avx512vnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avx512vnni(const Mat& bottom_blob, Mat& top_bl } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avxvnni.cpp b/src/layer/x86/convolution_x86_avxvnni.cpp index a8ef75bb968..aa1ba401856 100644 --- a/src/layer/x86/convolution_x86_avxvnni.cpp +++ b/src/layer/x86/convolution_x86_avxvnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avxvnni(const Mat& 
bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_xop.cpp b/src/layer/x86/convolution_x86_xop.cpp index d954f554565..cacba8f07cd 100644 --- a/src/layer/x86/convolution_x86_xop.cpp +++ b/src/layer/x86/convolution_x86_xop.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_xop(const Mat& bottom_blob, Mat& top_blob, con } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - 
conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/elu_x86.cpp b/src/layer/x86/elu_x86.cpp index 04f89b7b8eb..4fd1f84e1d4 100644 --- a/src/layer/x86/elu_x86.cpp +++ b/src/layer/x86/elu_x86.cpp @@ -29,9 +29,10 @@ int ELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { int w = bottom_top_blob.w; int h = bottom_top_blob.h; + int d = bottom_top_blob.d; int channels = bottom_top_blob.c; int elempack = bottom_top_blob.elempack; - int size = w * h * elempack; + int size = w * h * d * elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) diff --git a/src/layer/x86/gridsample_bicubic_apply_interpolation.h b/src/layer/x86/gridsample_bicubic_apply_interpolation.h new file mode 100644 index 00000000000..0b7be771d3b --- /dev/null +++ b/src/layer/x86/gridsample_bicubic_apply_interpolation.h @@ -0,0 +1,288 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void cubic_interp1d_p16(__m512& coeffs0, __m512& coeffs1, __m512& coeffs2, __m512& coeffs3, const __m512& tx) +{ + const __m512 A = _mm512_set1_ps(-0.75f); + + const __m512 x0 = _mm512_add_ps(tx, _mm512_set1_ps(1.0f)); + const __m512& x1 = tx; + const __m512 x2 = _mm512_sub_ps(_mm512_set1_ps(1.0f), tx); + + coeffs0 = _mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(A, x0), _mm512_mul_ps(_mm512_set1_ps(5.0f), A)), x0), _mm512_mul_ps(_mm512_set1_ps(8.0f), A)), x0), _mm512_mul_ps(_mm512_set1_ps(4), A)); + coeffs1 = _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(A, _mm512_set1_ps(2.0f)), x1), _mm512_add_ps(A, _mm512_set1_ps(3.0f))), x1), x1), _mm512_set1_ps(1.0f)); + coeffs2 = _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(A, _mm512_set1_ps(2.0f)), x2), _mm512_add_ps(A, _mm512_set1_ps(3.0f))), x2), x2), _mm512_set1_ps(1.0f)); + coeffs3 = _mm512_sub_ps(_mm512_sub_ps(_mm512_sub_ps(_mm512_set1_ps(1.0f), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p16(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m512 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m512 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m512 value_f[4]; + cubic_interp1d_p16(x_coeffs0, x_coeffs1, x_coeffs2, 
x_coeffs3, _mm512_set1_ps(offset_value_ptr[0])); + cubic_interp1d_p16(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm512_set1_ps(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m512 x0_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 x1_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 x2_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 x3_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + value_f[ii] = _mm512_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm512_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm512_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm512_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m512 _v = _mm512_mul_ps(y_coeffs0, value_f[0]); + _v = _mm512_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm512_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm512_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 18; + } + } +} + +#endif // __AVX512F__ +static void cubic_interp1d_p8(__m256& coeffs0, __m256& coeffs1, __m256& coeffs2, __m256& coeffs3, const __m256& tx) +{ + const __m256 A = _mm256_set1_ps(-0.75f); + + const __m256 x0 = _mm256_add_ps(tx, _mm256_set1_ps(1)); + const __m256& x1 = tx; + const __m256 x2 = _mm256_sub_ps(_mm256_set1_ps(1), tx); + //const __m256 x3 = _mm256_add_ps(x2, _mm256_set1_ps(1)); + + coeffs0 = _mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(A, x0), _mm256_mul_ps(_mm256_set1_ps(5.0f), A)), x0), _mm256_mul_ps(_mm256_set1_ps(8.0f), A)), x0), _mm256_mul_ps(_mm256_set1_ps(4), A)); + coeffs1 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(A, _mm256_set1_ps(2.0f)), x1), _mm256_add_ps(A, _mm256_set1_ps(3.0f))), x1), x1), _mm256_set1_ps(1)); + coeffs2 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(A, _mm256_set1_ps(2.0f)), x2), _mm256_add_ps(A, _mm256_set1_ps(3.0f))), x2), x2), _mm256_set1_ps(1)); + coeffs3 = _mm256_sub_ps(_mm256_sub_ps(_mm256_sub_ps(_mm256_set1_ps(1), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p8(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m256 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m256 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m256 value_f[4]; + cubic_interp1d_p8(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, _mm256_set1_ps(offset_value_ptr[0])); + cubic_interp1d_p8(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm256_set1_ps(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m256 x0_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 x1_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 x2_val = offset_ptr[2] >= 0 ? 
_mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 x3_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + value_f[ii] = _mm256_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m256 _v = _mm256_mul_ps(y_coeffs0, value_f[0]); + _v = _mm256_comp_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm256_comp_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm256_comp_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 18; + } + } +} + +#endif // __AVX__ +static void cubic_interp1d_p4(__m128& coeffs0, __m128& coeffs1, __m128& coeffs2, __m128& coeffs3, const __m128& tx) +{ + const __m128 A = _mm_set_ps1(-0.75f); + + const __m128 x0 = _mm_add_ps(tx, _mm_set_ps1(1.0f)); + const __m128& x1 = tx; + const __m128 x2 = _mm_sub_ps(_mm_set_ps1(1.0f), tx); + //const __m128 x3 = _mm_add_ps(x2, _mm_set_ps1(1.0f)); + + coeffs0 = _mm_sub_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(A, x0), _mm_mul_ps(_mm_set_ps1(5.0f), A)), x0), _mm_mul_ps(_mm_set_ps1(8.0f), A)), x0), _mm_mul_ps(_mm_set_ps1(4), A)); + coeffs1 = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(_mm_add_ps(A, _mm_set_ps1(2.0f)), x1), _mm_add_ps(A, _mm_set_ps1(3.0f))), x1), x1), _mm_set_ps1(1.0f)); + coeffs2 = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(_mm_add_ps(A, _mm_set_ps1(2.0f)), x2), _mm_add_ps(A, _mm_set_ps1(3.0f))), x2), x2), _mm_set_ps1(1.0f)); + coeffs3 = _mm_sub_ps(_mm_sub_ps(_mm_sub_ps(_mm_set_ps1(1.0f), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p4(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + __m128 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m128 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m128 value_f[4]; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + cubic_interp1d_p4(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, _mm_set_ps1(offset_value_ptr[0])); + cubic_interp1d_p4(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm_set_ps1(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m128 x0_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 x1_val = offset_ptr[1] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 x2_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 x3_val = offset_ptr[3] >= 0 ? 
_mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + value_f[ii] = _mm_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m128 _v = _mm_mul_ps(y_coeffs0, value_f[0]); + _v = _mm_comp_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm_comp_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm_comp_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 18; + } + } +} + +#endif // __SSE2__ + +static inline void cubic_interp1d(float& coeffs0, float& coeffs1, float& coeffs2, float& coeffs3, float fx) +{ + const float A = -0.75f; + + float fx0 = fx + 1; + float fx1 = fx; + float fx2 = 1 - fx; + // float fx3 = 2 - fx; + + coeffs0 = A * fx0 * fx0 * fx0 - 5 * A * fx0 * fx0 + 8 * A * fx0 - 4 * A; + coeffs1 = (A + 2) * fx1 * fx1 * fx1 - (A + 3) * fx1 * fx1 + 1; + coeffs2 = (A + 2) * fx2 * fx2 * fx2 - (A + 3) * fx2 * fx2 + 1; + coeffs3 = 1.f - coeffs0 - coeffs1 - coeffs2; +} + +static void gridsample_2d_bicubic_apply_interpolation_p1(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + float x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + float y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + float value_f[4]; + cubic_interp1d(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, offset_value_ptr[0]); + cubic_interp1d(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, offset_value_ptr[1]); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + float x0_val = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float x1_val = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float x2_val = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float x3_val = offset_ptr[3] >= 0 ? *(srcptr + offset_ptr[3]) : 0; + + value_f[ii] = x_coeffs0 * x0_val; + value_f[ii] = x_coeffs1 * x1_val + value_f[ii]; + value_f[ii] = x_coeffs2 * x2_val + value_f[ii]; + value_f[ii] = x_coeffs3 * x3_val + value_f[ii]; + + offset_ptr += 4; + } + + float _v = y_coeffs0 * value_f[0]; + _v = y_coeffs1 * value_f[1] + _v; + _v = y_coeffs2 * value_f[2] + _v; + _v = y_coeffs3 * value_f[3] + _v; + *dstptr = _v; + + dstptr++; + offset_value_ptr += 18; + } + } +} \ No newline at end of file diff --git a/src/layer/x86/gridsample_bicubic_compute_blob.h b/src/layer/x86/gridsample_bicubic_compute_blob.h new file mode 100644 index 00000000000..9006153d9d2 --- /dev/null +++ b/src/layer/x86/gridsample_bicubic_compute_blob.h @@ -0,0 +1,299 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_bicubic_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gy = unormalize(_mm256_set1_ps(src.h), gy); + + __m256 gx_floor = _mm256_floor_ps(gx); + __m256 gy_floor = _mm256_floor_ps(gy); + + __m256 tx = _mm256_sub_ps(gx, gx_floor); + __m256 ty = _mm256_sub_ps(gy, gy_floor); + + __m256 gx0 = _mm256_add_ps(gx_floor, _mm256_set1_ps(-1)); + __m256 gx1 = gx_floor; + __m256 gx2 = _mm256_add_ps(gx_floor, _mm256_set1_ps(1)); + __m256 gx3 = _mm256_add_ps(gx2, _mm256_set1_ps(1)); + + gx0 = get_coord(_mm256_set1_ps(src.w), gx0); + gx1 = get_coord(_mm256_set1_ps(src.w), gx1); + gx2 = get_coord(_mm256_set1_ps(src.w), gx2); + gx3 = get_coord(_mm256_set1_ps(src.w), gx3); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(gx0, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx0, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(gx1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx1, _CMP_GT_OS)); + __m256 x2_in_range = _mm256_and_ps(_mm256_cmp_ps(gx2, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx2, _CMP_GT_OS)); + __m256 x3_in_range = _mm256_and_ps(_mm256_cmp_ps(gx3, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx3, _CMP_GT_OS)); + __m256 v0_offset_f[4], v1_offset_f[4], v2_offset_f[4], v3_offset_f[4]; + for (int i = 0; i < 4; i++) + { + gy = _mm256_add_ps(gy_floor, _mm256_set1_ps(-1.0f + i)); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 y_in_range = _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS)); + + __m256 gy_offset = _mm256_mul_ps(gy, _mm256_set1_ps(src.w)); + + v0_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx0), _mm256_set1_ps(src.elempack)); + v1_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx1), _mm256_set1_ps(src.elempack)); + v2_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx2), _mm256_set1_ps(src.elempack)); + v3_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx3), _mm256_set1_ps(src.elempack)); + + v0_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v0_offset_f[i], _mm256_and_ps(x0_in_range, y_in_range)); + v1_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v1_offset_f[i], _mm256_and_ps(x1_in_range, y_in_range)); + v2_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v2_offset_f[i], _mm256_and_ps(x2_in_range, y_in_range)); + v3_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v3_offset_f[i], _mm256_and_ps(x3_in_range, y_in_range)); + + v0_offset_f[i] = 
_mm256_castsi256_ps(_mm256_cvtps_epi32(v0_offset_f[i])); + v1_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v1_offset_f[i])); + v2_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v2_offset_f[i])); + v3_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v3_offset_f[i])); + } + + transpose8x18_ps(tx, ty, v0_offset_f[0], v1_offset_f[0], v2_offset_f[0], v3_offset_f[0], v0_offset_f[1], v1_offset_f[1], v2_offset_f[1], v3_offset_f[1], v0_offset_f[2], v1_offset_f[2], v2_offset_f[2], v3_offset_f[2], v0_offset_f[3], v1_offset_f[3], v2_offset_f[3], v3_offset_f[3]); + + _mm256_storeu_ps(offset_value_ptr, tx); + _mm256_storeu_ps(offset_value_ptr + 8, ty); + offset_value_ptr += 16; + for (int i = 0; i < 4; i++) + { + _mm256_storeu_ps(offset_value_ptr, v0_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 8, v1_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 16, v2_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 24, v3_offset_f[i]); + offset_value_ptr += 32; + } + gridptr += 16; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_y = unormalize(src.h, sample_y); + + int x1 = floorf(sample_x); + int y1 = floorf(sample_y); + int x0 = x1 - 1; + int x2 = x1 + 1; + int x3 = x1 + 2; + + offset_value_ptr[0] = sample_x - static_cast(x1); + offset_value_ptr[1] = sample_y - static_cast(y1); + + x1 = get_coord(src.w, x1); + x0 = get_coord(src.w, x0); + x2 = get_coord(src.w, x2); + x3 = get_coord(src.w, x3); + + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool x2_in_range = (x2 > -1) & (x2 < src.w); + bool x3_in_range = (x3 > -1) & (x3 < src.w); + + int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int i = 0; i < 4; i++) + { + int gy = y1 + i - 1; + gy = get_coord(src.h, gy); + int offset_y = gy * src.w; + + bool y_in_range = (gy > -1) & (gy < src.h); + + bool v0_in_bound = (x0_in_range & y_in_range); + bool v1_in_bound = (x1_in_range & y_in_range); + bool v2_in_bound = (x2_in_range & y_in_range); + bool v3_in_bound = (x3_in_range & y_in_range); + + offset_ptr[0] = v0_in_bound ? (offset_y + x0) * src.elempack : -1.0; + offset_ptr[1] = v1_in_bound ? (offset_y + x1) * src.elempack : -1.0; + offset_ptr[2] = v2_in_bound ? (offset_y + x2) * src.elempack : -1.0; + offset_ptr[3] = v3_in_bound ? 
(offset_y + x3) * src.elempack : -1.0; + + offset_ptr += 4; + } + + gridptr += 2; + offset_value_ptr += 18; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gy = unormalize(_mm256_set1_ps(src.h), gy); + + __m256 gx_floor = _mm256_floor_ps(gx); + __m256 gy_floor = _mm256_floor_ps(gy); + + __m256 tx = _mm256_sub_ps(gx, gx_floor); + __m256 ty = _mm256_sub_ps(gy, gy_floor); + + __m256 gx0 = _mm256_add_ps(gx_floor, _mm256_set1_ps(-1)); + __m256 gx1 = gx_floor; + __m256 gx2 = _mm256_add_ps(gx_floor, _mm256_set1_ps(1)); + __m256 gx3 = _mm256_add_ps(gx2, _mm256_set1_ps(1)); + + gx0 = get_coord(_mm256_set1_ps(src.w), gx0); + gx1 = get_coord(_mm256_set1_ps(src.w), gx1); + gx2 = get_coord(_mm256_set1_ps(src.w), gx2); + gx3 = get_coord(_mm256_set1_ps(src.w), gx3); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(gx0, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx0, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(gx1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx1, _CMP_GT_OS)); + __m256 x2_in_range = _mm256_and_ps(_mm256_cmp_ps(gx2, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx2, _CMP_GT_OS)); + __m256 x3_in_range = _mm256_and_ps(_mm256_cmp_ps(gx3, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx3, _CMP_GT_OS)); + + __m256 v0_offset_f[4], v1_offset_f[4], v2_offset_f[4], v3_offset_f[4]; + for (int i = 0; i < 4; i++) + { + gy = _mm256_add_ps(gy_floor, _mm256_set1_ps(-1.0f + i)); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 y_in_range = _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS)); + + __m256 gy_offset = _mm256_mul_ps(gy, _mm256_set1_ps(src.w)); + + v0_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx0), _mm256_set1_ps(src.elempack)); + v1_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx1), _mm256_set1_ps(src.elempack)); + v2_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx2), _mm256_set1_ps(src.elempack)); + v3_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx3), _mm256_set1_ps(src.elempack)); + + v0_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v0_offset_f[i], _mm256_and_ps(x0_in_range, y_in_range)); + v1_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v1_offset_f[i], _mm256_and_ps(x1_in_range, y_in_range)); + v2_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v2_offset_f[i], _mm256_and_ps(x2_in_range, y_in_range)); + v3_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v3_offset_f[i], _mm256_and_ps(x3_in_range, y_in_range)); + + v0_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v0_offset_f[i])); + v1_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v1_offset_f[i])); + v2_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v2_offset_f[i])); + v3_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v3_offset_f[i])); + } + + transpose8x18_ps(tx, ty, v0_offset_f[0], v1_offset_f[0], v2_offset_f[0], v3_offset_f[0], v0_offset_f[1], v1_offset_f[1], v2_offset_f[1], v3_offset_f[1], v0_offset_f[2], v1_offset_f[2], v2_offset_f[2], v3_offset_f[2], v0_offset_f[3], v1_offset_f[3], v2_offset_f[3], v3_offset_f[3]); + + _mm256_storeu_ps(offset_value_ptr, tx); + _mm256_storeu_ps(offset_value_ptr 
+ 8, ty); + offset_value_ptr += 16; + for (int i = 0; i < 4; i++) + { + _mm256_storeu_ps(offset_value_ptr, v0_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 8, v1_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 16, v2_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 24, v3_offset_f[i]); + offset_value_ptr += 32; + } + + gridptr_x += 8; + gridptr_y += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_y = unormalize(src.h, sample_y); + + int x1 = floorf(sample_x); + int y1 = floorf(sample_y); + int x0 = x1 - 1; + int x2 = x1 + 1; + int x3 = x1 + 2; + + offset_value_ptr[0] = sample_x - static_cast(x1); + offset_value_ptr[1] = sample_y - static_cast(y1); + + x1 = get_coord(src.w, x1); + x0 = get_coord(src.w, x0); + x2 = get_coord(src.w, x2); + x3 = get_coord(src.w, x3); + + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool x2_in_range = (x2 > -1) & (x2 < src.w); + bool x3_in_range = (x3 > -1) & (x3 < src.w); + + int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int i = 0; i < 4; i++) + { + int gy = y1 + i - 1; + gy = get_coord(src.h, gy); + int offset_y = gy * src.w; + + bool y_in_range = (gy > -1) & (gy < src.h); + + bool v0_in_bound = (x0_in_range & y_in_range); + bool v1_in_bound = (x1_in_range & y_in_range); + bool v2_in_bound = (x2_in_range & y_in_range); + bool v3_in_bound = (x3_in_range & y_in_range); + + offset_ptr[0] = v0_in_bound ? (offset_y + x0) * src.elempack : -1.0; + offset_ptr[1] = v1_in_bound ? (offset_y + x1) * src.elempack : -1.0; + offset_ptr[2] = v2_in_bound ? (offset_y + x2) * src.elempack : -1.0; + offset_ptr[3] = v3_in_bound ? (offset_y + x3) * src.elempack : -1.0; + + offset_ptr += 4; + } + + gridptr_x++; + gridptr_y++; + + offset_value_ptr += 18; + } + } +} diff --git a/src/layer/x86/gridsample_bilinear_apply_interpolation.h b/src/layer/x86/gridsample_bilinear_apply_interpolation.h new file mode 100644 index 00000000000..56f2767e587 --- /dev/null +++ b/src/layer/x86/gridsample_bilinear_apply_interpolation.h @@ -0,0 +1,370 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
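A note on the arithmetic used throughout this new header: the compute_blob pass packs, per output element, the four (2-D) or eight (3-D) source offsets as raw ints followed by the fractional x/y(/z) weights (a negative offset marks an out-of-bounds tap that contributes zero), and the vector paths below evaluate the blend v0 + (v1 - v0) * t as an fmadd/fnmadd pair rather than an explicit subtract. The following is only a minimal scalar sketch of that identity and of the 6-float 2-D record layout; the helper names lerp_fma and sample_one_2d are illustrative and not part of this patch.

#include <cmath>

// Illustrative only: b*t + (a - a*t) == a + (b - a)*t,
// the same fmadd/fnmadd pairing used by the packed paths below.
static inline float lerp_fma(float a, float b, float t)
{
    return std::fma(b, t, std::fma(-a, t, a));
}

// Scalar view of one 2-D record: 4 int offsets, then {tx, ty}; offset < 0 means out of bounds.
static inline float sample_one_2d(const float* srcptr, const float* record)
{
    const int* offset = (const int*)record;
    float v00 = offset[0] >= 0 ? srcptr[offset[0]] : 0.f;
    float v01 = offset[1] >= 0 ? srcptr[offset[1]] : 0.f;
    float v10 = offset[2] >= 0 ? srcptr[offset[2]] : 0.f;
    float v11 = offset[3] >= 0 ? srcptr[offset[3]] : 0.f;
    float tx = record[4];
    float ty = record[5];
    return lerp_fma(lerp_fma(v00, v01, tx), lerp_fma(v10, v11, tx), ty);
}

The 3-D variants follow the same scheme with 8 offsets and 3 weights per record (stride 11 floats instead of 6).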
+ +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void gridsample_2d_bilinear_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m512 v00_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 v01_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 v10_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 v11_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + __m512 value1 = _mm512_set1_ps(value_ptr[0]); + __m512 v0 = _mm512_fmadd_ps(v01_val, value1, _mm512_fnmadd_ps(v00_val, value1, v00_val)); + __m512 v1 = _mm512_fmadd_ps(v11_val, value1, _mm512_fnmadd_ps(v10_val, value1, v10_val)); + + __m512 value2 = _mm512_set1_ps(value_ptr[1]); + __m512 _v = _mm512_fmadd_ps(v1, value2, _mm512_fnmadd_ps(v0, value2, v0)); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 6; + } + } +} + +static void gridsample_3d_bilinear_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m512 v000_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 v001_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 v010_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 v011_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + __m512 v100_val = offset_ptr[4] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[4]) : _mm512_set1_ps(0); + __m512 v101_val = offset_ptr[5] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[5]) : _mm512_set1_ps(0); + __m512 v110_val = offset_ptr[6] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[6]) : _mm512_set1_ps(0); + __m512 v111_val = offset_ptr[7] >= 0 ? 
_mm512_loadu_ps(srcptr + offset_ptr[7]) : _mm512_set1_ps(0); + + __m512 value = _mm512_set1_ps(value_ptr[0]); + __m512 v00 = _mm512_fmadd_ps(v001_val, value, _mm512_fnmadd_ps(v000_val, value, v000_val)); + __m512 v01 = _mm512_fmadd_ps(v011_val, value, _mm512_fnmadd_ps(v010_val, value, v010_val)); + __m512 v10 = _mm512_fmadd_ps(v101_val, value, _mm512_fnmadd_ps(v100_val, value, v100_val)); + __m512 v11 = _mm512_fmadd_ps(v111_val, value, _mm512_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm512_set1_ps(value_ptr[1]); + __m512 v0 = _mm512_fmadd_ps(v01, value, _mm512_fnmadd_ps(v00, value, v00)); + __m512 v1 = _mm512_fmadd_ps(v11, value, _mm512_fnmadd_ps(v10, value, v10)); + + value = _mm512_set1_ps(value_ptr[2]); + __m512 _v = _mm512_fmadd_ps(v1, value, _mm512_fnmadd_ps(v0, value, v0)); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 11; + } + } +} + +#endif // __AVX512F__ + +static void gridsample_2d_bilinear_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m256 v00_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 v01_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 v10_val = offset_ptr[2] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 v11_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + __m256 value1 = _mm256_set1_ps(value_ptr[0]); + __m256 v0 = _mm256_comp_fmadd_ps(v01_val, value1, _mm256_comp_fnmadd_ps(v00_val, value1, v00_val)); + __m256 v1 = _mm256_comp_fmadd_ps(v11_val, value1, _mm256_comp_fnmadd_ps(v10_val, value1, v10_val)); + + __m256 value2 = _mm256_set1_ps(value_ptr[1]); + __m256 _v = _mm256_comp_fmadd_ps(v1, value2, _mm256_comp_fnmadd_ps(v0, value2, v0)); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 6; + } + } +} +static void gridsample_3d_bilinear_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m256 v000_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 v001_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 v010_val = offset_ptr[2] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 v011_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + __m256 v100_val = offset_ptr[4] >= 0 ? 
_mm256_loadu_ps(srcptr + offset_ptr[4]) : _mm256_set1_ps(0); + __m256 v101_val = offset_ptr[5] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[5]) : _mm256_set1_ps(0); + __m256 v110_val = offset_ptr[6] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[6]) : _mm256_set1_ps(0); + __m256 v111_val = offset_ptr[7] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[7]) : _mm256_set1_ps(0); + + __m256 value = _mm256_set1_ps(value_ptr[0]); + __m256 v00 = _mm256_comp_fmadd_ps(v001_val, value, _mm256_comp_fnmadd_ps(v000_val, value, v000_val)); + __m256 v01 = _mm256_comp_fmadd_ps(v011_val, value, _mm256_comp_fnmadd_ps(v010_val, value, v010_val)); + __m256 v10 = _mm256_comp_fmadd_ps(v101_val, value, _mm256_comp_fnmadd_ps(v100_val, value, v100_val)); + __m256 v11 = _mm256_comp_fmadd_ps(v111_val, value, _mm256_comp_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm256_set1_ps(value_ptr[1]); + __m256 v0 = _mm256_comp_fmadd_ps(v01, value, _mm256_comp_fnmadd_ps(v00, value, v00)); + __m256 v1 = _mm256_comp_fmadd_ps(v11, value, _mm256_comp_fnmadd_ps(v10, value, v10)); + + value = _mm256_set1_ps(value_ptr[2]); + __m256 _v = _mm256_comp_fmadd_ps(v1, value, _mm256_comp_fnmadd_ps(v0, value, v0)); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 11; + } + } +} +#endif // __AVX__ +static void gridsample_2d_bilinear_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m128 v00_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 v01_val = offset_ptr[1] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 v10_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 v11_val = offset_ptr[3] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + __m128 value1 = _mm_set1_ps(value_ptr[0]); + __m128 v0 = _mm_comp_fmadd_ps(v01_val, value1, _mm_comp_fnmadd_ps(v00_val, value1, v00_val)); + __m128 v1 = _mm_comp_fmadd_ps(v11_val, value1, _mm_comp_fnmadd_ps(v10_val, value1, v10_val)); + + __m128 value2 = _mm_set1_ps(value_ptr[1]); + __m128 _v = _mm_comp_fmadd_ps(v1, value2, _mm_comp_fnmadd_ps(v0, value2, v0)); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 6; + } + } +} +static void gridsample_3d_bilinear_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m128 v000_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 v001_val = offset_ptr[1] >= 0 ? 
_mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 v010_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 v011_val = offset_ptr[3] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + __m128 v100_val = offset_ptr[4] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[4]) : _mm_set1_ps(0); + __m128 v101_val = offset_ptr[5] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[5]) : _mm_set1_ps(0); + __m128 v110_val = offset_ptr[6] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[6]) : _mm_set1_ps(0); + __m128 v111_val = offset_ptr[7] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[7]) : _mm_set1_ps(0); + + __m128 value = _mm_set1_ps(value_ptr[0]); + __m128 v00 = _mm_comp_fmadd_ps(v001_val, value, _mm_comp_fnmadd_ps(v000_val, value, v000_val)); + __m128 v01 = _mm_comp_fmadd_ps(v011_val, value, _mm_comp_fnmadd_ps(v010_val, value, v010_val)); + __m128 v10 = _mm_comp_fmadd_ps(v101_val, value, _mm_comp_fnmadd_ps(v100_val, value, v100_val)); + __m128 v11 = _mm_comp_fmadd_ps(v111_val, value, _mm_comp_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm_set1_ps(value_ptr[1]); + __m128 v0 = _mm_comp_fmadd_ps(v01, value, _mm_comp_fnmadd_ps(v00, value, v00)); + __m128 v1 = _mm_comp_fmadd_ps(v11, value, _mm_comp_fnmadd_ps(v10, value, v10)); + + value = _mm_set1_ps(value_ptr[2]); + __m128 _v = _mm_comp_fmadd_ps(v1, value, _mm_comp_fnmadd_ps(v0, value, v0)); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 11; + } + } +} +#endif // __SSE2__ + +static void gridsample_2d_bilinear_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + float v00 = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float v01 = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float v10 = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float v11 = offset_ptr[3] >= 0 ? *(srcptr + offset_ptr[3]) : 0; + + float v0 = v00 * (1 - value_ptr[0]) + v01 * value_ptr[0]; + float v1 = v10 * (1 - value_ptr[0]) + v11 * value_ptr[0]; + + *dstptr = v0 * (1 - value_ptr[1]) + v1 * value_ptr[1]; + + dstptr++; + offset_value_ptr += 6; + } + } +} + +static void gridsample_3d_bilinear_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + float v000 = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float v001 = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float v010 = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float v011 = offset_ptr[3] >= 0 ? 
*(srcptr + offset_ptr[3]) : 0; + + float v100 = offset_ptr[4] >= 0 ? *(srcptr + offset_ptr[4]) : 0; + float v101 = offset_ptr[5] >= 0 ? *(srcptr + offset_ptr[5]) : 0; + float v110 = offset_ptr[6] >= 0 ? *(srcptr + offset_ptr[6]) : 0; + float v111 = offset_ptr[7] >= 0 ? *(srcptr + offset_ptr[7]) : 0; + + float v00 = v000 * (1 - value_ptr[0]) + v001 * value_ptr[0]; + float v01 = v010 * (1 - value_ptr[0]) + v011 * value_ptr[0]; + float v10 = v100 * (1 - value_ptr[0]) + v101 * value_ptr[0]; + float v11 = v110 * (1 - value_ptr[0]) + v111 * value_ptr[0]; + + float v0 = v00 * (1 - value_ptr[1]) + v01 * value_ptr[1]; + float v1 = v10 * (1 - value_ptr[1]) + v11 * value_ptr[1]; + + *dstptr = v0 * (1 - value_ptr[2]) + v1 * value_ptr[2]; + + dstptr++; + offset_value_ptr += 11; + } + } +} diff --git a/src/layer/x86/gridsample_bilinear_compute_blob.h b/src/layer/x86/gridsample_bilinear_compute_blob.h new file mode 100644 index 00000000000..f78e017c1c7 --- /dev/null +++ b/src/layer/x86/gridsample_bilinear_compute_blob.h @@ -0,0 +1,623 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_bilinear_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, 
y1_in_range); + + __m256 nw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w), _mm256_set1_ps(src.elempack)); + __m256 ne_offset = _mm256_add_ps(nw_offset, _mm256_set1_ps(src.elempack)); + __m256 sw_offset = _mm256_comp_fmadd_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack), nw_offset); + __m256 se_offset = _mm256_add_ps(sw_offset, _mm256_set1_ps(src.elempack)); + + nw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), nw_offset, v00_in_range); + ne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), ne_offset, v01_in_range); + sw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), sw_offset, v10_in_range); + se_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), se_offset, v11_in_range); + + nw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(nw_offset)); + ne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(ne_offset)); + sw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(sw_offset)); + se_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(se_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + + transpose8x6_ps(nw_offset, ne_offset, sw_offset, se_offset, alpha, beta); + + _mm256_storeu_ps(offset_value_ptr, nw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, ne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, sw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, se_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, alpha); + _mm256_storeu_ps(offset_value_ptr + 40, beta); + + gridptr += 16; + offset_value_ptr += 48; + } +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int x1 = x0 + 1; + int y1 = y0 + 1; + + bool x0_in_bound = (x0 > -1) & (x0 < src.w); + bool x1_in_bound = (x1 > -1) & (x1 < src.w); + bool y0_in_bound = (y0 > -1) & (y0 < src.h); + bool y1_in_bound = (y1 > -1) & (y1 < src.h); + + bool in_bound_00 = x0_in_bound & y0_in_bound; + bool in_bound_01 = x1_in_bound & y0_in_bound; + bool in_bound_10 = x0_in_bound & y1_in_bound; + bool in_bound_11 = x1_in_bound & y1_in_bound; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 4; + + offset_ptr[0] = in_bound_00 ? (x0 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[1] = in_bound_01 ? (x1 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[2] = in_bound_10 ? (x0 + y1 * src.w) * src.elempack : -1.0; + offset_ptr[3] = in_bound_11 ? 
(x1 + y1 * src.w) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + + gridptr += 2; + offset_value_ptr += 6; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + __m256 nw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w), _mm256_set1_ps(src.elempack)); + __m256 ne_offset = _mm256_add_ps(nw_offset, _mm256_set1_ps(src.elempack)); + __m256 sw_offset = _mm256_comp_fmadd_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack), nw_offset); + __m256 se_offset = _mm256_add_ps(sw_offset, _mm256_set1_ps(src.elempack)); + + nw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), nw_offset, v00_in_range); + ne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), ne_offset, v01_in_range); + sw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), sw_offset, v10_in_range); + se_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), se_offset, v11_in_range); + + nw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(nw_offset)); + ne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(ne_offset)); + sw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(sw_offset)); + se_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(se_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + + transpose8x6_ps(nw_offset, ne_offset, sw_offset, se_offset, alpha, beta); + + _mm256_storeu_ps(offset_value_ptr, nw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, ne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, sw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, se_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, alpha); + _mm256_storeu_ps(offset_value_ptr + 40, beta); + + gridptr_x += 8; + gridptr_y += 8; + offset_value_ptr += 48; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int x1 = x0 + 1; + int y1 = y0 + 1; + + bool x0_in_bound = 
(x0 > -1) & (x0 < src.w); + bool x1_in_bound = (x1 > -1) & (x1 < src.w); + bool y0_in_bound = (y0 > -1) & (y0 < src.h); + bool y1_in_bound = (y1 > -1) & (y1 < src.h); + + bool in_bound_00 = x0_in_bound & y0_in_bound; + bool in_bound_01 = x1_in_bound & y0_in_bound; + bool in_bound_10 = x0_in_bound & y1_in_bound; + bool in_bound_11 = x1_in_bound & y1_in_bound; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 4; + + offset_ptr[0] = in_bound_00 ? (x0 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[1] = in_bound_01 ? (x1 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[2] = in_bound_10 ? (x0 + y1 * src.w) * src.elempack : -1.0; + offset_ptr[3] = in_bound_11 ? (x1 + y1 * src.w) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + + gridptr_x++; + gridptr_y++; + offset_value_ptr += 6; + } + } +} + +template +void gridsample_3d_bilinear_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h * grid.d; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 23 < grid_size; x += 24) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + __m256 gz = _mm256_loadu_ps(gridptr + 16); + + transpose3x8_ps(gx, gy, gz); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + __m256 z_t = _mm256_floor_ps(gz); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + __m256 z1 = _mm256_add_ps(z_t, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + __m256 z0_in_range = _mm256_and_ps(_mm256_cmp_ps(z_t, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z_t, _CMP_GT_OS)); + __m256 z1_in_range = _mm256_and_ps(_mm256_cmp_ps(z1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z1, _CMP_GT_OS)); + + __m256 v000_in_range, v010_in_range, v100_in_range, v110_in_range, v001_in_range, v011_in_range, v101_in_range, v111_in_range; + { + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + v000_in_range = _mm256_and_ps(v00_in_range, z0_in_range); + v001_in_range = _mm256_and_ps(v01_in_range, z0_in_range); + v010_in_range = _mm256_and_ps(v10_in_range, z0_in_range); + v011_in_range = 
_mm256_and_ps(v11_in_range, z0_in_range); + + v100_in_range = _mm256_and_ps(v00_in_range, z1_in_range); + v101_in_range = _mm256_and_ps(v01_in_range, z1_in_range); + v110_in_range = _mm256_and_ps(v10_in_range, z1_in_range); + v111_in_range = _mm256_and_ps(v11_in_range, z1_in_range); + } + + __m256 tnw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), z_t, + _mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w)), + _mm256_set1_ps(src.elempack)); + __m256 tne_offset = _mm256_add_ps(tnw_offset, _mm256_set1_ps(src.elempack)); + __m256 tsw_offset = _mm256_add_ps(tnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 tse_offset = _mm256_add_ps(tsw_offset, _mm256_set1_ps(src.elempack)); + + __m256 bnw_offset = _mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), _mm256_set1_ps(src.elempack), tnw_offset); + __m256 bne_offset = _mm256_add_ps(bnw_offset, _mm256_set1_ps(src.elempack)); + __m256 bsw_offset = _mm256_add_ps(bnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 bse_offset = _mm256_add_ps(bsw_offset, _mm256_set1_ps(src.elempack)); + + tnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tnw_offset, v000_in_range); + tne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tne_offset, v001_in_range); + tsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tsw_offset, v010_in_range); + tse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tse_offset, v011_in_range); + + bnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bnw_offset, v100_in_range); + bne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bne_offset, v101_in_range); + bsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bsw_offset, v110_in_range); + bse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bse_offset, v111_in_range); + + tnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tnw_offset)); + tne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tne_offset)); + tsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tsw_offset)); + tse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tse_offset)); + + bnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bnw_offset)); + bne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bne_offset)); + bsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bsw_offset)); + bse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bse_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + __m256 gamma = _mm256_sub_ps(gz, z_t); + + transpose8x11_ps(tnw_offset, tne_offset, tsw_offset, tse_offset, bnw_offset, bne_offset, bsw_offset, bse_offset, alpha, beta, gamma); + + _mm256_storeu_ps(offset_value_ptr, tnw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, tne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, tsw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, tse_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, bnw_offset); + _mm256_storeu_ps(offset_value_ptr + 40, bne_offset); + _mm256_storeu_ps(offset_value_ptr + 48, bsw_offset); + _mm256_storeu_ps(offset_value_ptr + 56, bse_offset); + + _mm256_storeu_ps(offset_value_ptr + 64, alpha); + _mm256_storeu_ps(offset_value_ptr + 72, beta); + _mm256_storeu_ps(offset_value_ptr + 80, gamma); + + gridptr += 24; + + offset_value_ptr += 88; + } +#endif // __AVX__ + + for (; x < grid_size; x += 3) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + float sample_z = *(gridptr + 2); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, 
sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int z0 = (int)floorf(sample_z); + int x1 = x0 + 1; + int y1 = y0 + 1; + int z1 = z0 + 1; + + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool y0_in_range = (y0 > -1) & (y0 < src.h); + bool z0_in_range = (z0 > -1) & (z0 < src.d); + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool y1_in_range = (y1 > -1) & (y1 < src.h); + bool z1_in_range = (z1 > -1) & (z1 < src.d); + + bool v00_in_range = x0_in_range & y0_in_range; + bool v01_in_range = x1_in_range & y0_in_range; + bool v10_in_range = x0_in_range & y1_in_range; + bool v11_in_range = x1_in_range & y1_in_range; + + bool in_bound_000 = v00_in_range & z0_in_range; + bool in_bound_001 = v01_in_range & z0_in_range; + bool in_bound_010 = v10_in_range & z0_in_range; + bool in_bound_011 = v11_in_range & z0_in_range; + + bool in_bound_100 = v00_in_range & z1_in_range; + bool in_bound_101 = v01_in_range & z1_in_range; + bool in_bound_110 = v10_in_range & z1_in_range; + bool in_bound_111 = v11_in_range & z1_in_range; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 8; + + offset_ptr[0] = in_bound_000 ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[1] = in_bound_001 ? (x1 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[2] = in_bound_010 ? (x0 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[3] = in_bound_011 ? (x1 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + offset_ptr[4] = in_bound_100 ? (x0 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[5] = in_bound_101 ? (x1 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[6] = in_bound_110 ? (x0 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[7] = in_bound_111 ? 
(x1 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + value_ptr[2] = sample_z - z0; + + gridptr += 3; + offset_value_ptr += 11; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + __m256 gz = _mm256_loadu_ps(gridptr_z); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + __m256 z_t = _mm256_floor_ps(gz); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + __m256 z1 = _mm256_add_ps(z_t, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + __m256 z0_in_range = _mm256_and_ps(_mm256_cmp_ps(z_t, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z_t, _CMP_GT_OS)); + __m256 z1_in_range = _mm256_and_ps(_mm256_cmp_ps(z1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z1, _CMP_GT_OS)); + + __m256 v000_in_range, v010_in_range, v100_in_range, v110_in_range, v001_in_range, v011_in_range, v101_in_range, v111_in_range; + { + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + v000_in_range = _mm256_and_ps(v00_in_range, z0_in_range); + v001_in_range = _mm256_and_ps(v01_in_range, z0_in_range); + v010_in_range = _mm256_and_ps(v10_in_range, z0_in_range); + v011_in_range = _mm256_and_ps(v11_in_range, z0_in_range); + + v100_in_range = _mm256_and_ps(v00_in_range, z1_in_range); + v101_in_range = _mm256_and_ps(v01_in_range, z1_in_range); + v110_in_range = _mm256_and_ps(v10_in_range, z1_in_range); + v111_in_range = _mm256_and_ps(v11_in_range, z1_in_range); + } + + __m256 tnw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), z_t, + _mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w)), + _mm256_set1_ps(src.elempack)); + __m256 tne_offset = _mm256_add_ps(tnw_offset, _mm256_set1_ps(src.elempack)); + __m256 tsw_offset = _mm256_add_ps(tnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 tse_offset = _mm256_add_ps(tsw_offset, _mm256_set1_ps(src.elempack)); + + __m256 bnw_offset = _mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), _mm256_set1_ps(src.elempack), tnw_offset); + __m256 bne_offset = _mm256_add_ps(bnw_offset, 
_mm256_set1_ps(src.elempack)); + __m256 bsw_offset = _mm256_add_ps(bnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 bse_offset = _mm256_add_ps(bsw_offset, _mm256_set1_ps(src.elempack)); + + tnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tnw_offset, v000_in_range); + tne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tne_offset, v001_in_range); + tsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tsw_offset, v010_in_range); + tse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tse_offset, v011_in_range); + + bnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bnw_offset, v100_in_range); + bne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bne_offset, v101_in_range); + bsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bsw_offset, v110_in_range); + bse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bse_offset, v111_in_range); + + tnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tnw_offset)); + tne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tne_offset)); + tsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tsw_offset)); + tse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tse_offset)); + + bnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bnw_offset)); + bne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bne_offset)); + bsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bsw_offset)); + bse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bse_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + __m256 gamma = _mm256_sub_ps(gz, z_t); + + transpose8x11_ps(tnw_offset, tne_offset, tsw_offset, tse_offset, bnw_offset, bne_offset, bsw_offset, bse_offset, alpha, beta, gamma); + + _mm256_storeu_ps(offset_value_ptr, tnw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, tne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, tsw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, tse_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, bnw_offset); + _mm256_storeu_ps(offset_value_ptr + 40, bne_offset); + _mm256_storeu_ps(offset_value_ptr + 48, bsw_offset); + _mm256_storeu_ps(offset_value_ptr + 56, bse_offset); + + _mm256_storeu_ps(offset_value_ptr + 64, alpha); + _mm256_storeu_ps(offset_value_ptr + 72, beta); + _mm256_storeu_ps(offset_value_ptr + 80, gamma); + + gridptr_x += 8; + gridptr_y += 8; + gridptr_z += 8; + + offset_value_ptr += 88; + } +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int z0 = (int)floorf(sample_z); + int x1 = x0 + 1; + int y1 = y0 + 1; + int z1 = z0 + 1; + + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool y0_in_range = (y0 > -1) & (y0 < src.h); + bool z0_in_range = (z0 > -1) & (z0 < src.d); + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool y1_in_range = (y1 > -1) & (y1 < src.h); + bool z1_in_range = (z1 > -1) & (z1 < src.d); + + bool v00_in_range = x0_in_range & y0_in_range; + bool v01_in_range = x1_in_range & y0_in_range; + bool v10_in_range = x0_in_range & y1_in_range; + bool v11_in_range = x1_in_range & y1_in_range; + + bool in_bound_000 = v00_in_range & z0_in_range; + bool in_bound_001 = v01_in_range & z0_in_range; + bool in_bound_010 = 
v10_in_range & z0_in_range; + bool in_bound_011 = v11_in_range & z0_in_range; + + bool in_bound_100 = v00_in_range & z1_in_range; + bool in_bound_101 = v01_in_range & z1_in_range; + bool in_bound_110 = v10_in_range & z1_in_range; + bool in_bound_111 = v11_in_range & z1_in_range; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 8; + + offset_ptr[0] = in_bound_000 ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[1] = in_bound_001 ? (x1 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[2] = in_bound_010 ? (x0 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[3] = in_bound_011 ? (x1 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + offset_ptr[4] = in_bound_100 ? (x0 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[5] = in_bound_101 ? (x1 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[6] = in_bound_110 ? (x0 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[7] = in_bound_111 ? (x1 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + value_ptr[2] = sample_z - z0; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + offset_value_ptr += 11; + } + } +} diff --git a/src/layer/x86/gridsample_compute_blob.h b/src/layer/x86/gridsample_compute_blob.h new file mode 100644 index 00000000000..1fc903664cc --- /dev/null +++ b/src/layer/x86/gridsample_compute_blob.h @@ -0,0 +1,145 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
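+
+// Layout of the offset_value blob produced by the *_compute_blob functions and
+// consumed by the *_apply_interpolation kernels (one packed record per output sample):
+//   2d bilinear: 6 floats  = 4 neighbour offsets (int bits stored in float lanes, -1 marks out of bound) + alpha, beta
+//   3d bilinear: 11 floats = 8 neighbour offsets + alpha, beta, gamma
+//   2d bicubic:  18 floats per output sample (see gridsample_bicubic_compute_blob.h)
+//   nearest:     1 offset per output sample
+// Offsets are pre-multiplied by src.elempack so they index packed channels directly.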
+ +#include "x86_usability.h" + +template +struct grid_sample_unormalize; + +template<> +struct grid_sample_unormalize +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + return _mm256_mul_ps(_mm256_div_ps(_mm256_add_ps(coord, _mm256_set1_ps(1)), _mm256_set1_ps(2)), _mm256_sub_ps(length, _mm256_set1_ps(1))); + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return (coord + 1) / 2.f * (length - 1); + } +}; + +template<> +struct grid_sample_unormalize +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + return _mm256_div_ps(_mm256_comp_fmsub_ps(_mm256_add_ps(coord, _mm256_set1_ps(1)), length, _mm256_set1_ps(1)), _mm256_set1_ps(2)); + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return ((coord + 1) * length - 1) / 2.f; + } +}; + +template +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 /*length*/, __m256 coord) + { + return coord; + } +#endif // __AVX__ + float operator()(int /*length*/, float coord) + { + return coord; + } +}; + +template +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + coord = _mm256_min_ps(border_x, _mm256_max_ps(coord, _mm256_setzero_ps())); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +template<> +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + coord = abs256_ps(coord); + + __m256 reflectx_v = abs256_ps(_mm256_sub_ps(coord, border_x)); + coord = _mm256_sub_ps(border_x, reflectx_v); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + coord = fabs(coord); + coord = (length - 1) - fabs(coord - (length - 1)); + + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +template<> +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + __m256 v0p5fp8 = _mm256_set1_ps(0.5f); + coord = _mm256_add_ps(coord, v0p5fp8); + + coord = abs256_ps(coord); + + __m256 reflectx_v = abs256_ps(_mm256_sub_ps(coord, length)); + coord = _mm256_sub_ps(length, reflectx_v); + + coord = _mm256_sub_ps(coord, v0p5fp8); + + _mm256_sub_ps(coord, v0p5fp8); + + coord = _mm256_min_ps(border_x, _mm256_max_ps(coord, _mm256_setzero_ps())); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + coord = fabs(coord + 0.5f); + coord = length - fabs(coord - length) - 0.5; + + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +#include "gridsample_bilinear_compute_blob.h" +#include "gridsample_bicubic_compute_blob.h" +#include "gridsample_nearest_compute_blob.h" diff --git a/src/layer/x86/gridsample_nearest_apply_interpolation.h b/src/layer/x86/gridsample_nearest_apply_interpolation.h new file mode 100644 index 00000000000..e84cdc7de25 --- /dev/null +++ b/src/layer/x86/gridsample_nearest_apply_interpolation.h @@ -0,0 +1,126 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void gridsample_nearest_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m512 _v = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + offset_ptr++; + + _mm512_storeu_ps(dstptr, _v); + dstptr += 16; + } + } +} +#endif // __AVX512F__ + +static void gridsample_nearest_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m256 _v = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + offset_ptr++; + + _mm256_storeu_ps(dstptr, _v); + dstptr += 8; + } + } +} +#endif // __AVX__ +static void gridsample_nearest_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m128 _v = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + offset_ptr++; + + _mm_storeu_ps(dstptr, _v); + dstptr += 4; + } + } +} + +#endif // __SSE2__ + +static void gridsample_nearest_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + *dstptr = offset_ptr[0] >= 0 ? 
*(srcptr + offset_ptr[0]) : 0; + + offset_ptr++; + dstptr++; + } + } +} diff --git a/src/layer/x86/gridsample_nearest_compute_blob.h b/src/layer/x86/gridsample_nearest_compute_blob.h new file mode 100644 index 00000000000..a7a12066d21 --- /dev/null +++ b/src/layer/x86/gridsample_nearest_compute_blob.h @@ -0,0 +1,315 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_nearest_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx), _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr += 16; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? 
(x0 + y0 * src.w) * src.elempack : -1.0; + + gridptr += 2; + offset_ptr++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx), _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr_x += 8; + gridptr_y += 8; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? 
(x0 + y0 * src.w) * src.elempack : -1.0; + + gridptr_x++; + gridptr_y++; + + offset_ptr++; + } + } +} + +template +void gridsample_3d_nearest_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h * grid.d; + + float* offset_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 23 < grid_size; x += 24) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + __m256 gz = _mm256_loadu_ps(gridptr + 16); + + transpose3x8_ps(gx, gy, gz); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + gz = _mm256_floor_ps(_mm256_add_ps(gz, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + v_in_range = _mm256_and_ps(v_in_range, _mm256_and_ps(_mm256_cmp_ps(gz, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), gz, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), gz, + _mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx)), + _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr += 24; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 3) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + float sample_z = *(gridptr + 2); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + int z0 = static_cast(floorf(sample_z + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h) & (z0 > -1) & (z0 < src.d)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? 
(x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + gridptr += 3; + offset_ptr++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + __m256 gz = _mm256_loadu_ps(gridptr_z); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + gz = _mm256_floor_ps(_mm256_add_ps(gz, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + v_in_range = _mm256_and_ps(v_in_range, _mm256_and_ps(_mm256_cmp_ps(gz, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), gz, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), gz, + _mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx)), + _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr_x += 8; + gridptr_y += 8; + gridptr_z += 8; + + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + int z0 = static_cast(floorf(sample_z + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h) & (z0 > -1) & (z0 < src.d)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + + offset_ptr++; + } + } +} diff --git a/src/layer/x86/gridsample_x86.cpp b/src/layer/x86/gridsample_x86.cpp new file mode 100644 index 00000000000..004bc4d0895 --- /dev/null +++ b/src/layer/x86/gridsample_x86.cpp @@ -0,0 +1,455 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#include "gridsample_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" +#include "cpu.h" + +namespace ncnn { + +#include "gridsample_compute_blob.h" +#include "gridsample_bilinear_apply_interpolation.h" +#include "gridsample_bicubic_apply_interpolation.h" +#include "gridsample_nearest_apply_interpolation.h" + +GridSample_x86::GridSample_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + +int GridSample_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& grid = bottom_blobs[1]; + Mat& top_blob = top_blobs[0]; + int elempack = bottom_blob.elempack; + + int channels = bottom_blob.c; + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + + int outw, outh, outd; + Mat offset_value_blob; + + Mat grid_p1; + if (grid.elempack != 1) + { + convert_packing(grid, grid_p1, 1, opt); + } + else + { + grid_p1 = grid; + } + + if (dims == 3) + { + outw = permute_fusion == 0 ? grid_p1.h : grid_p1.w; + outh = permute_fusion == 0 ? grid_p1.c : grid_p1.h; + + top_blob.create(outw, outh, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (sample_type == GridSample::Interpolation_BILINEAR) + { + offset_value_blob.create(outw, outh, elemsize * 6, 6, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_NEAREST) + { + offset_value_blob.create(outw, outh, 1, elemsize, 1, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + 
if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_BICUBIC) + { + offset_value_blob.create(outw, outh, elemsize * 18, 18, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + } + + if (dims == 4) + { + outw = permute_fusion == 0 ? grid_p1.h : grid_p1.w; + outh = permute_fusion == 0 ? grid_p1.d : grid_p1.h; + outd = permute_fusion == 0 ? grid_p1.c : grid_p1.d; + + top_blob.create(outw, outh, outd, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (sample_type == GridSample::Interpolation_BILINEAR) + { + offset_value_blob.create(outw, outh, outd, elemsize * 11, 11, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_NEAREST) + { + offset_value_blob.create(outw, outh, outd, 1, elemsize, 1, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == 
GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == 3) + { + NCNN_LOGE("unsupported bicubic when dims == 4"); + return -100; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + +#endif // __AVX__ + if (elempack == 4) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + +#endif // 
__SSE2__ + + if (elempack == 1) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/gridsample_x86.h b/src/layer/x86/gridsample_x86.h new file mode 100644 index 00000000000..826414eefc9 --- /dev/null +++ b/src/layer/x86/gridsample_x86.h @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
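The GridSample_x86::forward above splits the work into two stages: a compute_blob pass that converts the normalized grid into per-output-point tap offsets and interpolation weights (held in offset_value_blob), and an apply_interpolation pass specialized per elempack (p16/p8/p4/p1). As a minimal scalar sketch of the underlying 2D bilinear sampling, assuming zeros padding and align_corner == 0 — the function name and flat-array layout are illustrative, not part of the patch:

#include <cmath>
#include <vector>

// src is an h x w single-channel image; grid holds (x, y) pairs in [-1, 1]
// for each of the outw * outh output points.
static std::vector<float> grid_sample_bilinear_sketch(const std::vector<float>& src, int w, int h,
                                                      const std::vector<float>& grid, int outw, int outh)
{
    std::vector<float> dst(outw * outh, 0.f);
    for (int i = 0; i < outw * outh; i++)
    {
        // unnormalize the coordinate (align_corner == 0 convention)
        float fx = ((grid[i * 2 + 0] + 1.f) * w - 1.f) * 0.5f;
        float fy = ((grid[i * 2 + 1] + 1.f) * h - 1.f) * 0.5f;

        int x0 = (int)floorf(fx);
        int y0 = (int)floorf(fy);
        float ax = fx - x0;
        float ay = fy - y0;

        // fetch the four neighbours, treating out-of-range taps as zero (zeros padding)
        auto at = [&](int x, int y) -> float {
            return (x >= 0 && x < w && y >= 0 && y < h) ? src[y * w + x] : 0.f;
        };

        float v00 = at(x0, y0),     v01 = at(x0 + 1, y0);
        float v10 = at(x0, y0 + 1), v11 = at(x0 + 1, y0 + 1);

        dst[i] = (v00 * (1.f - ax) + v01 * ax) * (1.f - ay)
               + (v10 * (1.f - ax) + v11 * ax) * ay;
    }
    return dst;
}

The x86 path precomputes exactly these offsets and fractions once per output point (six values for the 2D bilinear case, judging by the elemsize * 6 workspace), so the per-elempack apply loops only gather and blend.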
+ +#ifndef LAYER_GRIDSAMPLE_X86_H +#define LAYER_GRIDSAMPLE_X86_H + +#include "gridsample.h" + +namespace ncnn { + +class GridSample_x86 : virtual public GridSample +{ +public: + GridSample_x86(); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_GRIDSAMPLE_X86_H diff --git a/src/layer/x86/interp_x86.cpp b/src/layer/x86/interp_x86.cpp index 193fbe99a2d..f08b6bb9aff 100644 --- a/src/layer/x86/interp_x86.cpp +++ b/src/layer/x86/interp_x86.cpp @@ -14,8 +14,6 @@ #include "interp_x86.h" -#include - #if __SSE2__ #include #if __AVX__ diff --git a/src/layer/x86/layernorm_x86.cpp b/src/layer/x86/layernorm_x86.cpp index ba293fb95c6..21840c6b3d2 100644 --- a/src/layer/x86/layernorm_x86.cpp +++ b/src/layer/x86/layernorm_x86.cpp @@ -14,7 +14,7 @@ #include "layernorm_x86.h" #include "x86_usability.h" -#include + #include #if __SSE2__ diff --git a/src/layer/x86/lrn_x86.cpp b/src/layer/x86/lrn_x86.cpp index cfcc8777b45..b05c75996a1 100644 --- a/src/layer/x86/lrn_x86.cpp +++ b/src/layer/x86/lrn_x86.cpp @@ -18,8 +18,6 @@ #include "avx_mathfun.h" #endif // __AVX__ -#include - namespace ncnn { int LRN_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 21f528361e2..6ba218e53d3 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -24,7 +24,6 @@ #include "x86_activation.h" #include "x86_usability.h" -#include #include "layer_type.h" namespace ncnn { diff --git a/src/layer/x86/mish_x86.cpp b/src/layer/x86/mish_x86.cpp index 2a45cabd2d9..e55a5e1f808 100644 --- a/src/layer/x86/mish_x86.cpp +++ b/src/layer/x86/mish_x86.cpp @@ -16,8 +16,6 @@ #include "x86_activation.h" -#include - namespace ncnn { Mish_x86::Mish_x86() diff --git a/src/layer/x86/quantize_x86.cpp b/src/layer/x86/quantize_x86.cpp index e4a9157cd24..8f7ee993673 100644 --- a/src/layer/x86/quantize_x86.cpp +++ b/src/layer/x86/quantize_x86.cpp @@ -14,8 +14,6 @@ #include "quantize_x86.h" -#include - #if __SSE2__ #include #if __AVX__ diff --git a/src/layer/x86/roialign_x86.cpp b/src/layer/x86/roialign_x86.cpp index 7c5be4b751e..0519376770f 100644 --- a/src/layer/x86/roialign_x86.cpp +++ b/src/layer/x86/roialign_x86.cpp @@ -14,8 +14,6 @@ #include "roialign_x86.h" -#include - namespace ncnn { // adapted from detectron2 diff --git a/src/layer/x86/selu_x86.cpp b/src/layer/x86/selu_x86.cpp new file mode 100644 index 00000000000..0980673957c --- /dev/null +++ b/src/layer/x86/selu_x86.cpp @@ -0,0 +1,135 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
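The new selu_x86.cpp that begins here vectorizes the SELU activation, y = lambda * (max(0, x) + min(0, alpha * (exp(x) - 1))), across AVX-512/AVX/SSE2 with a scalar tail. For reference, a scalar version equivalent to that tail loop (helper name is illustrative):

#include <math.h>

// Scalar reference for the formula the SSE/AVX/AVX-512 loops implement.
static void selu_reference(float* ptr, int size, float alpha, float lambda)
{
    const float alphaxlambda = alpha * lambda;
    for (int i = 0; i < size; i++)
    {
        float x = ptr[i];
        // negative branch: lambda * alpha * (exp(x) - 1); positive branch: lambda * x
        ptr[i] = x < 0.f ? (expf(x) - 1.f) * alphaxlambda : x * lambda;
    }
}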
+ +#include "selu_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +namespace ncnn { + +SELU_x86::SELU_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + +int SELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int elempack = bottom_top_blob.elempack; + int channels = bottom_top_blob.c; + int size = w * h * d * elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + int i = 0; + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + __m512 _zero512 = _mm512_setzero_ps(); + __m512 _one512 = _mm512_set1_ps(1.f); + __m512 _alpha512 = _mm512_set1_ps(alpha); + __m512 _lambda512 = _mm512_set1_ps(lambda); + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + + __m512 _pos = _mm512_max_ps(_zero512, _p); + __m512 _neg = _mm512_min_ps(_zero512, _p); + + __m512 _blob = exp512_ps(_neg); + _blob = _mm512_sub_ps(_blob, _one512); + _blob = _mm512_mul_ps(_alpha512, _blob); + _blob = _mm512_mul_ps(_lambda512, _mm512_add_ps(_pos, _blob)); + + _mm512_storeu_ps(ptr, _blob); + + ptr += 16; + } +#endif // __AVX512F__ + __m256 _zero256 = _mm256_setzero_ps(); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _alpha256 = _mm256_set1_ps(alpha); + __m256 _lambda256 = _mm256_set1_ps(lambda); + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + + __m256 _pos = _mm256_max_ps(_zero256, _p); + __m256 _neg = _mm256_min_ps(_zero256, _p); + + __m256 _blob = exp256_ps(_neg); + _blob = _mm256_sub_ps(_blob, _one256); + _blob = _mm256_mul_ps(_alpha256, _blob); + _blob = _mm256_mul_ps(_lambda256, _mm256_add_ps(_pos, _blob)); + + _mm256_storeu_ps(ptr, _blob); + + ptr += 8; + } +#endif // __AVX__ + __m128 _zero128 = _mm_setzero_ps(); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _alpha128 = _mm_set1_ps(alpha); + __m128 _lambda128 = _mm_set1_ps(lambda); + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + + __m128 _pos = _mm_max_ps(_zero128, _p); + __m128 _neg = _mm_min_ps(_zero128, _p); + + __m128 _blob = exp_ps(_neg); + _blob = _mm_sub_ps(_blob, _one128); + _blob = _mm_mul_ps(_alpha128, _blob); + _blob = _mm_mul_ps(_lambda128, _mm_add_ps(_pos, _blob)); + + _mm_storeu_ps(ptr, _blob); + + ptr += 4; + } +#endif // __SSE2__ + float alphaxlambda = alpha * lambda; + for (; i < size; i++) + { + // y = lambda * ( max(0, x) + min(0, alpha * (exp(x) - 1)) ) + if (*ptr < 0) + *ptr = (expf(*ptr) - 1.f) * alphaxlambda; + else + *ptr = *ptr * lambda; + ptr++; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/selu_x86.h b/src/layer/x86/selu_x86.h new file mode 100644 index 00000000000..d7b5bf8a87e --- /dev/null +++ b/src/layer/x86/selu_x86.h @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_SELU_X86_H +#define LAYER_SELU_X86_H + +#include "selu.h" + +namespace ncnn { + +class SELU_x86 : virtual public SELU +{ +public: + SELU_x86(); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_SELU_X86_H diff --git a/src/layer/x86/shufflechannel_x86.cpp b/src/layer/x86/shufflechannel_x86.cpp new file mode 100644 index 00000000000..8afb22b2e2e --- /dev/null +++ b/src/layer/x86/shufflechannel_x86.cpp @@ -0,0 +1,772 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "shufflechannel_x86.h" + +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // __SSE2__ + +namespace ncnn { + +ShuffleChannel_x86::ShuffleChannel_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + +int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elembits = bottom_blob.elembits(); + if (elembits != 32) + { + NCNN_LOGE("Elembits = %d is not implemented yet.", elembits); + return -100; + } + + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + int size = w * h; + + int _group = reverse ? 
channels * elempack / group : group; + int channels_per_group = channels / _group; + + if (_group == 1) + { + top_blob = bottom_blob; + return 0; + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + __m512i _idxlo = _mm512_set_epi64( + 0x1700000007, 0x1600000006, + 0x1500000005, 0x1400000004, + 0x1300000003, 0x1200000002, + 0x1100000001, 0x1000000000); + __m512i _idxhi = _mm512_set_epi64( + 0x1f0000000f, 0x1e0000000e, + 0x1d0000000d, 0x1c0000000c, + 0x1b0000000b, 0x1a0000000a, + 0x1900000009, 0x1800000008); + + if (_group == 2 && channels % _group != 0) + { + top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group + q + 1); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + __m512 _p0 = _mm512_loadu_ps(ptr0); + __m512 _p1 = _mm512_loadu_ps(ptr1); + __m512 _p2 = _mm512_loadu_ps(ptr2); + + __m512 _p12 = _mm512_castsi512_ps( + _mm512_alignr_epi64(_mm512_castps_si512(_p2), _mm512_castps_si512(_p1), 4)); + + __m512 _lo = _mm512_permutex2var_ps(_p0, _idxlo, _p12); + __m512 _hi = _mm512_permutex2var_ps(_p0, _idxhi, _p12); + + _mm512_storeu_ps(outptr0, _lo); + _mm512_storeu_ps(outptr1, _hi); + + ptr0 += 16; + ptr1 += 16; + ptr2 += 16; + outptr0 += 16; + outptr1 += 16; + } + } + + // handle the last channel + { + const float* ptr0 = bottom_blob.channel(channels_per_group); + const float* ptr1 = bottom_blob.channel(channels_per_group * 2); + float* outptr = top_blob.channel(channels_per_group * 2); + + ptr1 += 8; + + for (int i = 0; i < size; i++) + { + __m256 _p0 = _mm256_loadu_ps(ptr0); + __m256 _p1 = _mm256_loadu_ps(ptr1); + + __m256 _lo = _mm256_unpacklo_ps(_p0, _p1); + __m256 _hi = _mm256_unpackhi_ps(_p0, _p1); + + __m256 _lo_ = _mm256_permute2f128_ps(_lo, _hi, 0x20); + __m256 _hi_ = _mm256_permute2f128_ps(_lo, _hi, 0x31); + + _mm256_storeu_ps(outptr, _lo_); + _mm256_storeu_ps(outptr + 8, _hi_); + + ptr0 += 16; + ptr1 += 16; + outptr += 16; + } + } + + return 0; + } + if (_group > 4 || channels % _group != 0) + { + // slow path for too large group or shuffle inside elempack + Option opt_pack = opt; + opt_pack.blob_allocator = opt.workspace_allocator; + + Mat bottom_blob_unpacked; + convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack); + + Mat top_blob_unpacked; + int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack); + if (ret != 0) + return ret; + + convert_packing(top_blob_unpacked, top_blob, elempack, opt); + + return 0; + } + + top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (_group == 2) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + __m512 _p0 = _mm512_loadu_ps(ptr0); + __m512 _p1 = _mm512_loadu_ps(ptr1); + + __m512 _lo = _mm512_permutex2var_ps(_p0, _idxlo, _p1); + __m512 _hi = _mm512_permutex2var_ps(_p0, _idxhi, _p1); + + _mm512_storeu_ps(outptr0, _lo); + _mm512_storeu_ps(outptr1, _hi); + + ptr0 += 16; + ptr1 += 16; + outptr0 += 16; + outptr1 += 
16; + } + } + + return 0; + } + if (_group == 3) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q); + float* outptr0 = top_blob.channel(q * 3); + float* outptr1 = top_blob.channel(q * 3 + 1); + float* outptr2 = top_blob.channel(q * 3 + 2); + + for (int i = 0; i < size; i++) + { + // TODO Naive implementation + /* + 0123456789abcdef 0gw1hx2iy3jz4kA5 + ghijklmnopqrstuv ---> lB6mC7nD8oE9pFaq + wxyzABCDEFGHIJKL GbrHcsIdtJeuKfvL + */ + + outptr0[0] = ptr0[0]; + outptr0[1] = ptr1[0]; + outptr0[2] = ptr2[0]; + outptr0[3] = ptr0[1]; + outptr0[4] = ptr1[1]; + outptr0[5] = ptr2[1]; + outptr0[6] = ptr0[2]; + outptr0[7] = ptr1[2]; + outptr0[8] = ptr2[2]; + outptr0[9] = ptr0[3]; + outptr0[10] = ptr1[3]; + outptr0[11] = ptr2[3]; + outptr0[12] = ptr0[4]; + outptr0[13] = ptr1[4]; + outptr0[14] = ptr2[4]; + outptr0[15] = ptr0[5]; + + outptr1[0] = ptr1[5]; + outptr1[1] = ptr2[5]; + outptr1[2] = ptr0[6]; + outptr1[3] = ptr1[6]; + outptr1[4] = ptr2[6]; + outptr1[5] = ptr0[7]; + outptr1[6] = ptr1[7]; + outptr1[7] = ptr2[7]; + outptr1[8] = ptr0[8]; + outptr1[9] = ptr1[8]; + outptr1[10] = ptr2[8]; + outptr1[11] = ptr0[9]; + outptr1[12] = ptr1[9]; + outptr1[13] = ptr2[9]; + outptr1[14] = ptr0[10]; + outptr1[15] = ptr1[10]; + + outptr2[0] = ptr2[10]; + outptr2[1] = ptr0[11]; + outptr2[2] = ptr1[11]; + outptr2[3] = ptr2[11]; + outptr2[4] = ptr0[12]; + outptr2[5] = ptr1[12]; + outptr2[6] = ptr2[12]; + outptr2[7] = ptr0[13]; + outptr2[8] = ptr1[13]; + outptr2[9] = ptr2[13]; + outptr2[10] = ptr0[14]; + outptr2[11] = ptr1[14]; + outptr2[12] = ptr2[14]; + outptr2[13] = ptr0[15]; + outptr2[14] = ptr1[15]; + outptr2[15] = ptr2[15]; + + ptr0 += 16; + ptr1 += 16; + ptr2 += 16; + outptr0 += 16; + outptr1 += 16; + outptr2 += 16; + } + } + + return 0; + } + if (_group == 4) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q); + const float* ptr3 = bottom_blob.channel(channels_per_group * 3 + q); + float* outptr0 = top_blob.channel(q * 4); + float* outptr1 = top_blob.channel(q * 4 + 1); + float* outptr2 = top_blob.channel(q * 4 + 2); + float* outptr3 = top_blob.channel(q * 4 + 3); + + for (int i = 0; i < size; i++) + { + __m512 _p0 = _mm512_loadu_ps(ptr0); + __m512 _p1 = _mm512_loadu_ps(ptr1); + __m512 _p2 = _mm512_loadu_ps(ptr2); + __m512 _p3 = _mm512_loadu_ps(ptr3); + + __m512 _lo02 = _mm512_permutex2var_ps(_p0, _idxlo, _p2); + __m512 _hi02 = _mm512_permutex2var_ps(_p0, _idxhi, _p2); + __m512 _lo13 = _mm512_permutex2var_ps(_p1, _idxlo, _p3); + __m512 _hi13 = _mm512_permutex2var_ps(_p1, _idxhi, _p3); + + __m512 _lolo = _mm512_permutex2var_ps(_lo02, _idxlo, _lo13); + __m512 _lohi = _mm512_permutex2var_ps(_lo02, _idxhi, _lo13); + __m512 _hilo = _mm512_permutex2var_ps(_hi02, _idxlo, _hi13); + __m512 _hihi = _mm512_permutex2var_ps(_hi02, _idxhi, _hi13); + + _mm512_storeu_ps(outptr0, _lolo); + _mm512_storeu_ps(outptr1, _lohi); + _mm512_storeu_ps(outptr2, _hilo); + _mm512_storeu_ps(outptr3, _hihi); + + ptr0 += 16; + ptr1 += 16; + ptr2 += 16; + ptr3 += 16; + outptr0 += 16; + outptr1 += 16; + outptr2 += 16; + outptr3 += 16; + } + } + + return 0; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + if (_group == 2 && channels % _group != 0) + { + 
top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group + q + 1); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + ptr1 += 4; + + for (int i = 0; i < size; i++) + { + __m256 _p0 = _mm256_loadu_ps(ptr0); + + __m256 _p1 = _mm256_castps128_ps256(_mm_loadu_ps(ptr1)); + _p1 = _mm256_insertf128_ps(_p1, _mm_loadu_ps(ptr2), 1); + + __m256 _lo = _mm256_unpacklo_ps(_p0, _p1); + __m256 _hi = _mm256_unpackhi_ps(_p0, _p1); + + __m256 _lo_ = _mm256_permute2f128_ps(_lo, _hi, 0x20); + __m256 _hi_ = _mm256_permute2f128_ps(_lo, _hi, 0x31); + + _mm256_storeu_ps(outptr0, _lo_); + _mm256_storeu_ps(outptr1, _hi_); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + outptr0 += 8; + outptr1 += 8; + } + } + + // handle the last channel + { + const float* ptr0 = bottom_blob.channel(channels_per_group); + const float* ptr1 = bottom_blob.channel(channels_per_group * 2); + float* outptr = top_blob.channel(channels_per_group * 2); + + ptr1 += 4; + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = _mm_loadu_ps(ptr1); + + __m128 _lo = _mm_unpacklo_ps(_p0, _p1); + __m128 _hi = _mm_unpackhi_ps(_p0, _p1); + + _mm_storeu_ps(outptr, _lo); + _mm_storeu_ps(outptr + 4, _hi); + + ptr0 += 8; + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + if (_group > 4 || channels % _group != 0) + { + // slow path for too large group or shuffle inside elempack + Option opt_pack = opt; + opt_pack.blob_allocator = opt.workspace_allocator; + + Mat bottom_blob_unpacked; + convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack); + + Mat top_blob_unpacked; + int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack); + if (ret != 0) + return ret; + + convert_packing(top_blob_unpacked, top_blob, elempack, opt); + + return 0; + } + + top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (_group == 2) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + __m256 _p0 = _mm256_loadu_ps(ptr0); + __m256 _p1 = _mm256_loadu_ps(ptr1); + + __m256 _lo = _mm256_unpacklo_ps(_p0, _p1); + __m256 _hi = _mm256_unpackhi_ps(_p0, _p1); + + __m256 _lo_ = _mm256_permute2f128_ps(_lo, _hi, 0x20); + __m256 _hi_ = _mm256_permute2f128_ps(_lo, _hi, 0x31); + + _mm256_storeu_ps(outptr0, _lo_); + _mm256_storeu_ps(outptr1, _hi_); + + ptr0 += 8; + ptr1 += 8; + outptr0 += 8; + outptr1 += 8; + } + } + + return 0; + } + if (_group == 3) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q); + float* outptr0 = top_blob.channel(q * 3); + float* outptr1 = top_blob.channel(q * 3 + 1); + float* outptr2 = top_blob.channel(q * 3 + 2); + + for (int i = 0; i < size; i++) + { + // TODO figure out a faster way + /* + 01234567 08g19h2a + 89abcdef ---> i3bj4ck5 + ghijklmn dl6em7fn + */ + + __m256 _p0 = 
_mm256_loadu_ps(ptr0); // 01234567 + __m256 _p1 = _mm256_loadu_ps(ptr1); // 89abcdef + __m256 _p2 = _mm256_loadu_ps(ptr2); // ghijklmn + + __m256 _08194c5d = _mm256_unpacklo_ps(_p0, _p1); + __m256 _2a3b6e7f = _mm256_unpackhi_ps(_p0, _p1); + __m256 _8g9hckdl = _mm256_unpacklo_ps(_p1, _p2); + __m256 _aibjemfn = _mm256_unpackhi_ps(_p1, _p2); + __m256 _0g1h4k5l = _mm256_unpacklo_ps(_p0, _p2); + __m256 _2i3j6m7n = _mm256_unpackhi_ps(_p0, _p2); + + __m256 _i3g1m7k5 = _mm256_shuffle_ps(_2i3j6m7n, _0g1h4k5l, _MM_SHUFFLE(2, 1, 2, 1)); + + __m256 _9h2adl6e = _mm256_shuffle_ps(_8g9hckdl, _2a3b6e7f, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _08g14ck5 = _mm256_shuffle_ps(_08194c5d, _i3g1m7k5, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _i3bjm7fn = _mm256_shuffle_ps(_i3g1m7k5, _aibjemfn, _MM_SHUFFLE(3, 2, 1, 0)); + + __m256 _08g19h2a = _mm256_permute2f128_ps(_08g14ck5, _9h2adl6e, 0x20); // 0 2 + __m256 _i3bj4ck5 = _mm256_permute2f128_ps(_i3bjm7fn, _08g14ck5, 0x30); // 0 3 + __m256 _dl6em7fn = _mm256_permute2f128_ps(_9h2adl6e, _i3bjm7fn, 0x31); // 1 3 + + _mm256_storeu_ps(outptr0, _08g19h2a); + _mm256_storeu_ps(outptr1, _i3bj4ck5); + _mm256_storeu_ps(outptr2, _dl6em7fn); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + outptr0 += 8; + outptr1 += 8; + outptr2 += 8; + } + } + + return 0; + } + if (_group == 4) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q); + const float* ptr3 = bottom_blob.channel(channels_per_group * 3 + q); + float* outptr0 = top_blob.channel(q * 4); + float* outptr1 = top_blob.channel(q * 4 + 1); + float* outptr2 = top_blob.channel(q * 4 + 2); + float* outptr3 = top_blob.channel(q * 4 + 3); + + for (int i = 0; i < size; i++) + { + __m256 _p0 = _mm256_loadu_ps(ptr0); + __m256 _p1 = _mm256_loadu_ps(ptr1); + __m256 _p2 = _mm256_loadu_ps(ptr2); + __m256 _p3 = _mm256_loadu_ps(ptr3); + + __m256 _lo02 = _mm256_unpacklo_ps(_p0, _p2); + __m256 _hi02 = _mm256_unpackhi_ps(_p0, _p2); + __m256 _lo13 = _mm256_unpacklo_ps(_p1, _p3); + __m256 _hi13 = _mm256_unpackhi_ps(_p1, _p3); + + __m256 _lolo = _mm256_unpacklo_ps(_lo02, _lo13); + __m256 _lohi = _mm256_unpackhi_ps(_lo02, _lo13); + __m256 _hilo = _mm256_unpacklo_ps(_hi02, _hi13); + __m256 _hihi = _mm256_unpackhi_ps(_hi02, _hi13); + + __m256 _lolo_ = _mm256_permute2f128_ps(_lolo, _lohi, 0x20); + __m256 _lohi_ = _mm256_permute2f128_ps(_hilo, _hihi, 0x20); + __m256 _hilo_ = _mm256_permute2f128_ps(_lolo, _lohi, 0x31); + __m256 _hihi_ = _mm256_permute2f128_ps(_hilo, _hihi, 0x31); + + _mm256_storeu_ps(outptr0, _lolo_); + _mm256_storeu_ps(outptr1, _lohi_); + _mm256_storeu_ps(outptr2, _hilo_); + _mm256_storeu_ps(outptr3, _hihi_); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + ptr3 += 8; + outptr0 += 8; + outptr1 += 8; + outptr2 += 8; + outptr3 += 8; + } + } + + return 0; + } + } +#endif // __AVX__ + if (elempack == 4) + { + if (_group == 2 && channels % _group != 0) + { + top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group + q + 1); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = 
_mm_loadu_ps(ptr1); + __m128 _p2 = _mm_loadu_ps(ptr2); + + __m128 _p12 = _mm_shuffle_ps(_p1, _p2, _MM_SHUFFLE(1, 0, 3, 2)); + + __m128 _lo = _mm_unpacklo_ps(_p0, _p12); + __m128 _hi = _mm_unpackhi_ps(_p0, _p12); + + _mm_storeu_ps(outptr0, _lo); + _mm_storeu_ps(outptr1, _hi); + + ptr0 += 4; + ptr1 += 4; + ptr2 += 4; + outptr0 += 4; + outptr1 += 4; + } + } + + // handle the last channel + { + const float* ptr0 = bottom_blob.channel(channels_per_group); + const float* ptr1 = bottom_blob.channel(channels_per_group * 2); + float* outptr = top_blob.channel(channels_per_group * 2); + + ptr1 += 2; + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = _mm_loadu_ps(ptr1); + + __m128 _lo = _mm_unpacklo_ps(_p0, _p1); + + _mm_storeu_ps(outptr, _lo); + + ptr0 += 4; + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (_group > 4 || channels % _group != 0) + { + // slow path for too large group or shuffle inside elempack + Option opt_pack = opt; + opt_pack.blob_allocator = opt.workspace_allocator; + + Mat bottom_blob_unpacked; + convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack); + + Mat top_blob_unpacked; + int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack); + if (ret != 0) + return ret; + + convert_packing(top_blob_unpacked, top_blob, elempack, opt); + + return 0; + } + + top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (_group == 2) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + float* outptr0 = top_blob.channel(q * 2); + float* outptr1 = top_blob.channel(q * 2 + 1); + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = _mm_loadu_ps(ptr1); + + __m128 _lo = _mm_unpacklo_ps(_p0, _p1); + __m128 _hi = _mm_unpackhi_ps(_p0, _p1); + + _mm_storeu_ps(outptr0, _lo); + _mm_storeu_ps(outptr1, _hi); + + ptr0 += 4; + ptr1 += 4; + outptr0 += 4; + outptr1 += 4; + } + } + + return 0; + } + if (_group == 3) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q); + float* outptr0 = top_blob.channel(q * 3); + float* outptr1 = top_blob.channel(q * 3 + 1); + float* outptr2 = top_blob.channel(q * 3 + 2); + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 _p2 = _mm_loadu_ps(ptr2); + + __m128 _0415 = _mm_unpacklo_ps(_p0, _p1); + __m128 _2637 = _mm_unpackhi_ps(_p0, _p1); + __m128 _4859 = _mm_unpacklo_ps(_p1, _p2); + __m128 _6a7b = _mm_unpackhi_ps(_p1, _p2); + + __m128 _138a = _mm_shuffle_ps(_p0, _p2, _MM_SHUFFLE(2, 0, 3, 1)); + + __m128 _0481 = _mm_shuffle_ps(_0415, _138a, _MM_SHUFFLE(0, 2, 1, 0)); + __m128 _5926 = _mm_shuffle_ps(_4859, _2637, _MM_SHUFFLE(1, 0, 3, 2)); + __m128 _a37b = _mm_shuffle_ps(_138a, _6a7b, _MM_SHUFFLE(3, 2, 1, 3)); + + _mm_storeu_ps(outptr0, _0481); + _mm_storeu_ps(outptr1, _5926); + _mm_storeu_ps(outptr2, _a37b); + + ptr0 += 4; + ptr1 += 4; + ptr2 += 4; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + } + } + + return 0; + } + if (_group == 4) + { + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr0 = bottom_blob.channel(q); + const float* ptr1 = bottom_blob.channel(channels_per_group + q); + const float* ptr2 = 
bottom_blob.channel(channels_per_group * 2 + q); + const float* ptr3 = bottom_blob.channel(channels_per_group * 3 + q); + float* outptr0 = top_blob.channel(q * 4); + float* outptr1 = top_blob.channel(q * 4 + 1); + float* outptr2 = top_blob.channel(q * 4 + 2); + float* outptr3 = top_blob.channel(q * 4 + 3); + + for (int i = 0; i < size; i++) + { + __m128 _p0 = _mm_loadu_ps(ptr0); + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 _p2 = _mm_loadu_ps(ptr2); + __m128 _p3 = _mm_loadu_ps(ptr3); + + __m128 _lo02 = _mm_unpacklo_ps(_p0, _p2); + __m128 _hi02 = _mm_unpackhi_ps(_p0, _p2); + __m128 _lo13 = _mm_unpacklo_ps(_p1, _p3); + __m128 _hi13 = _mm_unpackhi_ps(_p1, _p3); + + __m128 _lolo = _mm_unpacklo_ps(_lo02, _lo13); + __m128 _lohi = _mm_unpackhi_ps(_lo02, _lo13); + __m128 _hilo = _mm_unpacklo_ps(_hi02, _hi13); + __m128 _hihi = _mm_unpackhi_ps(_hi02, _hi13); + + _mm_storeu_ps(outptr0, _lolo); + _mm_storeu_ps(outptr1, _lohi); + _mm_storeu_ps(outptr2, _hilo); + _mm_storeu_ps(outptr3, _hihi); + + ptr0 += 4; + ptr1 += 4; + ptr2 += 4; + ptr3 += 4; + outptr0 += 4; + outptr1 += 4; + outptr2 += 4; + outptr3 += 4; + } + } + + return 0; + } + } +#endif // __SSE2__ + + return ShuffleChannel::forward(bottom_blob, top_blob, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/shufflechannel_x86.h b/src/layer/x86/shufflechannel_x86.h new file mode 100644 index 00000000000..6adca483c17 --- /dev/null +++ b/src/layer/x86/shufflechannel_x86.h @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
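The ShuffleChannel_x86 paths above specialize the standard channel shuffle: view the channel axis as (group, channels_per_group), transpose it to (channels_per_group, group), and write the result back, with the reverse flag swapping the roles of the two factors. A scalar sketch for the elempack == 1 case, using plain arrays instead of ncnn::Mat (hypothetical helper):

#include <vector>

// size is w * h per channel; dst channel (c * group + g) receives
// src channel (g * channels_per_group + c).
static std::vector<float> shuffle_channel_sketch(const std::vector<float>& src,
                                                 int channels, int size, int group)
{
    std::vector<float> dst(src.size());
    const int channels_per_group = channels / group;
    for (int g = 0; g < group; g++)
    {
        for (int c = 0; c < channels_per_group; c++)
        {
            const float* in = &src[(g * channels_per_group + c) * size];
            float* out = &dst[(c * group + g) * size];
            for (int i = 0; i < size; i++)
                out[i] = in[i];
        }
    }
    return dst;
}

The SSE/AVX/AVX-512 branches implement the same permutation for packed layouts with unpack/shuffle/permute instructions, falling back to unpack-to-elempack-1 plus the generic ShuffleChannel::forward when the group is too large or does not divide the channel count.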
+ +#ifndef LAYER_SHUFFLECHANNEL_X86_H +#define LAYER_SHUFFLECHANNEL_X86_H + +#include "shufflechannel.h" + +namespace ncnn { + +class ShuffleChannel_x86 : virtual public ShuffleChannel +{ +public: + ShuffleChannel_x86(); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_SHUFFLECHANNEL_X86_H diff --git a/src/layer/x86/sigmoid_x86.cpp b/src/layer/x86/sigmoid_x86.cpp index ed55d20859b..0cf44f84591 100644 --- a/src/layer/x86/sigmoid_x86.cpp +++ b/src/layer/x86/sigmoid_x86.cpp @@ -26,8 +26,6 @@ #endif // __AVX__ #endif // __SSE2__ -#include - namespace ncnn { Sigmoid_x86::Sigmoid_x86() diff --git a/src/layer/x86/softmax_x86.cpp b/src/layer/x86/softmax_x86.cpp index 07e7c535af2..41e5bd25d2e 100644 --- a/src/layer/x86/softmax_x86.cpp +++ b/src/layer/x86/softmax_x86.cpp @@ -15,7 +15,6 @@ #include "softmax_x86.h" #include -#include #if __SSE2__ #include diff --git a/src/layer/x86/swish_x86.cpp b/src/layer/x86/swish_x86.cpp index 73a074fb9ad..d8ae2695016 100644 --- a/src/layer/x86/swish_x86.cpp +++ b/src/layer/x86/swish_x86.cpp @@ -26,8 +26,6 @@ #endif // __AVX__ #endif // __SSE2__ -#include - namespace ncnn { Swish_x86::Swish_x86() diff --git a/src/layer/x86/tanh_x86.cpp b/src/layer/x86/tanh_x86.cpp index 2cebf19c2d3..bf94450e9fb 100644 --- a/src/layer/x86/tanh_x86.cpp +++ b/src/layer/x86/tanh_x86.cpp @@ -16,8 +16,6 @@ #include "x86_activation.h" -#include - namespace ncnn { TanH_x86::TanH_x86() diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 8629ab2093b..1ccd50d601a 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -14,9 +14,8 @@ #include "unaryop_x86.h" -#include +// #include #include -#include #if __SSE2__ #include diff --git a/src/layer/x86/x86_activation.h b/src/layer/x86/x86_activation.h index b02b8ee9a46..691bc65ee4c 100644 --- a/src/layer/x86/x86_activation.h +++ b/src/layer/x86/x86_activation.h @@ -15,7 +15,6 @@ #ifndef X86_ACTIVATION_H #define X86_ACTIVATION_H -#include #include "mat.h" #include "fused_activation.h" #include "x86_usability.h" diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index c9551330ff6..9cb826fa2b1 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -15,7 +15,6 @@ #ifndef X86_USABILITY_H #define X86_USABILITY_H -#include #if __SSE2__ #include #if __SSE4_1__ @@ -42,6 +41,83 @@ static NCNN_FORCEINLINE signed char float2int8(float v) } #if __SSE2__ +static NCNN_FORCEINLINE void transpose4x8_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi32(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi32(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi32(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi32(_r6, _r7); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpacklo_epi64(_tmp4, _tmp6); + _r2 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r3 = _mm_unpackhi_epi64(_tmp4, _tmp6); + _r4 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r5 = _mm_unpacklo_epi64(_tmp5, _tmp7); + _r6 = _mm_unpackhi_epi64(_tmp1, _tmp3); + _r7 = _mm_unpackhi_epi64(_tmp5, _tmp7); +} + +static NCNN_FORCEINLINE void transpose4x4_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = 
_mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi64(_tmp1, _tmp3); +} + +static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); + + __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); + __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); + __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); + __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); + __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); + __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); + __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); + __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); + + _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); + _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); + _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); + _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); + _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); + _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); + _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); + _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); +} + +static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + + _r0 = _mm_unpacklo_epi32(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi32(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi32(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); +} + static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -196,6 +272,14 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128 { return _mm_sub_ps(_c, _mm_mul_ps(_a, _b)); } +static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_sub_ps(_mm_mul_ps(_a, _b), _c); +} +static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1))); +} #else static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { @@ -206,6 +290,14 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128 // return -a * b + c return _mm_fnmadd_ps(_a, _b, _c); } +static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_fmsub_ps(_a, _b, _c); +} +static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_fnmsub_ps(_a, _b, _c); +} #endif // !__FMA__ #if __AVX__ @@ -218,9 +310,18 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m { return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b)); } +static NCNN_FORCEINLINE __m256 
_mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c); +} +static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1))); +} #else static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) { + // return a * b + c return _mm256_fmadd_ps(_a, _b, _c); } static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) @@ -228,6 +329,16 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m // return -a * b + c return _mm256_fnmadd_ps(_a, _b, _c); } +static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + // return a * b - c + return _mm256_fmsub_ps(_a, _b, _c); +} +static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + // return -(a * b) - c + return _mm256_fnmsub_ps(_a, _b, _c); +} #endif static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c) @@ -341,34 +452,161 @@ static NCNN_FORCEINLINE void transpose8x2_ps(__m256& _r0, __m256& _r1) _r1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); } -static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +static NCNN_FORCEINLINE void transpose2x8_ps(__m256& _r0, __m256& _r1) { - __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); - __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); - __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); - __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 1)); - __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); - __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); - __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); - __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); - __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); - __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); - __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); - __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); + _r0 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); +} - _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); - _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); - _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); - _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); - _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); - _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); - _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); - _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); +static NCNN_FORCEINLINE void transpose3x8_ps(__m256& _r0, __m256& _r1, __m256& _r2) +{ + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r2, _MM_SHUFFLE(0, 3, 0, 0)); + + __m256 _tmp4 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 2, 1)); + __m256 _tmp5 = _mm256_shuffle_ps(_tmp1, _tmp2, _MM_SHUFFLE(2, 1, 3, 2)); + + _r0 = _mm256_shuffle_ps(_tmp0, 
_tmp5, _MM_SHUFFLE(2, 0, 3, 0)); + _r1 = _mm256_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _r2 = _mm256_shuffle_ps(_tmp4, _tmp2, _MM_SHUFFLE(3, 0, 3, 1)); +} + +static NCNN_FORCEINLINE void transpose8x6_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + + __m256 _tmp6 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp7 = _mm256_shuffle_ps(_tmp4, _tmp0, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmp8 = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmp9 = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpa = _mm256_shuffle_ps(_tmp5, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpb = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + _r4 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static NCNN_FORCEINLINE void transpose8x11_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5, __m256& _r6, __m256& _r7, __m256& _r8, __m256& _r9, __m256& _ra) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _r0); + __m256 _tmpb = _mm256_shuffle_ps(_ra, _tmp1, _MM_SHUFFLE(3, 2, 1, 2)); + __m256 _tmpc = _mm256_unpacklo_ps(_r1, _r2); + __m256 _tmpd = _mm256_unpackhi_ps(_r1, _r2); + __m256 _tmpe = _mm256_unpacklo_ps(_r3, _r4); + __m256 _tmpf = _mm256_unpackhi_ps(_r3, _r4); + __m256 _tmpg = _mm256_unpacklo_ps(_r5, _r6); + __m256 _tmph = _mm256_unpackhi_ps(_r5, _r6); + __m256 _tmpi = _mm256_unpacklo_ps(_r7, _r8); + __m256 _tmpj = _mm256_unpackhi_ps(_r7, _r8); + __m256 _tmpk = _mm256_unpacklo_ps(_r9, _ra); + __m256 _tmpl = _mm256_unpackhi_ps(_r9, _ra); + + __m256 _tmpm = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 0, 1, 0)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpg, _tmpi, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmpk, _tmp1, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _tmps = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 2, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpw = _mm256_shuffle_ps(_tmpj, _tmpl, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = 
_mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmpw, _tmpm, _MM_SHUFFLE(0, 3, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpn, _tmpo, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2f128_ps(_tmpp, _tmpq, _MM_SHUFFLE(0, 3, 0, 1)); + _r8 = _mm256_permute2f128_ps(_tmpr, _tmps, _MM_SHUFFLE(0, 3, 0, 1)); + _r9 = _mm256_permute2f128_ps(_tmpt, _tmpu, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpv, _tmpw, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static void transpose8x18_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5, __m256& _r6, __m256& _r7, __m256& _r8, __m256& _r9, __m256& _ra, __m256& _rb, __m256& _rc, __m256& _rd, __m256& _re, __m256& _rf, __m256& _rg, __m256& _rh) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _rb); + __m256 _tmpb = _mm256_unpackhi_ps(_ra, _rb); + __m256 _tmpc = _mm256_unpacklo_ps(_rc, _rd); + __m256 _tmpd = _mm256_unpackhi_ps(_rc, _rd); + __m256 _tmpe = _mm256_unpacklo_ps(_re, _rf); + __m256 _tmpf = _mm256_unpackhi_ps(_re, _rf); + __m256 _tmpg = _mm256_unpacklo_ps(_rg, _rh); + __m256 _tmph = _mm256_unpackhi_ps(_rg, _rh); + + __m256 _tmpi = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpj = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpk = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpl = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpm = _mm256_shuffle_ps(_tmpg, _tmp0, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp6, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpa, _tmpc, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpe, _tmpg, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmps = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmph, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpw = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpx = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpy = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpz = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = 
_mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r7 = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 2, 0, 0)); + _r8 = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 2, 0, 0)); + _r9 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 3, 0, 1)); + _rb = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 3, 0, 1)); + _rc = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 3, 0, 1)); + _rd = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 3, 0, 1)); + _re = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 3, 0, 1)); + _rf = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 3, 0, 1)); + _rg = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 3, 0, 1)); + _rh = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 3, 0, 1)); } static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7) @@ -598,6 +836,55 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& return _v; } +#if __AVX2__ +static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1) +{ + __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_r0, _r1); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7) +{ + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); + __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); + __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); + __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); + __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); + + __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); + __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); + __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); + __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); + __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); + __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); + __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); + __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); + + _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); + _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); + _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); + _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); + _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); + _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm); + _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn); + _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); + _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2x128_si256(_tmp6, 
_tmp7, _MM_SHUFFLE(0, 3, 0, 1)); +} + #if __AVX512F__ static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7, __m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf) @@ -928,45 +1215,6 @@ static NCNN_FORCEINLINE void transpose16x16_epi16(__m256i& _r0, __m256i& _r1, __ _rf = _mm256_permute2x128_si256(_tmp7, _tmpf, _MM_SHUFFLE(0, 3, 0, 1)); } -static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7) -{ - __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); - __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); - __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); - __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); - __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); - __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); - __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); - __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); - - __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); - __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); - __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); - __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); - __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); - __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); - __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); - __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); - - _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); - _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); - _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); - _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); - _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); - _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm); - _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn); - _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn); - - _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); - _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); - _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); - _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); - _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); - _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); - _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); - _r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); -} - static NCNN_FORCEINLINE void transpose8x16_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7, __m128i& _r8, __m128i& _r9, __m128i& _ra, __m128i& _rb, __m128i& _rc, __m128i& _rd, __m128i& _re, __m128i& _rf) { __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); @@ -1088,6 +1336,7 @@ static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m5 } #endif // __AVX512F__ +#endif // __AVX2__ #endif // __AVX__ #endif // __SSE2__ diff --git a/src/layer/x86/yolov3detectionoutput_x86.cpp b/src/layer/x86/yolov3detectionoutput_x86.cpp index 10f26945004..175d7343524 100644 --- a/src/layer/x86/yolov3detectionoutput_x86.cpp +++ b/src/layer/x86/yolov3detectionoutput_x86.cpp @@ -18,7 +18,6 @@ #include "yolov3detectionoutput_x86.h" #include -#include namespace ncnn { diff --git a/src/layer/yolodetectionoutput.cpp b/src/layer/yolodetectionoutput.cpp index 967b14751f8..9b9ba7dc289 100644 --- a/src/layer/yolodetectionoutput.cpp +++ b/src/layer/yolodetectionoutput.cpp @@ -16,8 +16,6 @@ #include 
"layer_type.h" -#include - namespace ncnn { YoloDetectionOutput::YoloDetectionOutput() diff --git a/src/layer/yolov3detectionoutput.cpp b/src/layer/yolov3detectionoutput.cpp index 0cda9616746..494fb6d186a 100644 --- a/src/layer/yolov3detectionoutput.cpp +++ b/src/layer/yolov3detectionoutput.cpp @@ -17,7 +17,6 @@ #include "layer_type.h" #include -#include namespace ncnn { diff --git a/src/mat.cpp b/src/mat.cpp index 6e1cd702522..f758df41d40 100644 --- a/src/mat.cpp +++ b/src/mat.cpp @@ -21,8 +21,6 @@ #include "layer.h" #include "layer_type.h" -#include - #if NCNN_VULKAN #if NCNN_PLATFORM_API #if __ANDROID_API__ >= 26 diff --git a/src/mat_pixel.cpp b/src/mat_pixel.cpp index ce9d4c479e0..221c7e5b2f8 100644 --- a/src/mat_pixel.cpp +++ b/src/mat_pixel.cpp @@ -15,7 +15,7 @@ #include "mat.h" #include -#include + #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/mat_pixel_affine.cpp b/src/mat_pixel_affine.cpp index c2abe363d96..934fe22b1d5 100644 --- a/src/mat_pixel_affine.cpp +++ b/src/mat_pixel_affine.cpp @@ -17,7 +17,7 @@ #include #endif // __ARM_NEON #include -#include + #include "platform.h" namespace ncnn { diff --git a/src/mat_pixel_resize.cpp b/src/mat_pixel_resize.cpp index 7d171338469..e8f138d2a54 100644 --- a/src/mat_pixel_resize.cpp +++ b/src/mat_pixel_resize.cpp @@ -15,7 +15,7 @@ #include "mat.h" #include -#include + #if __ARM_NEON #include #endif // __ARM_NEON diff --git a/src/net.cpp b/src/net.cpp index aed2f20a48e..f4e70e98ae0 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -610,67 +610,41 @@ int NetPrivate::forward_layer(int layer_index, std::vector& blob_mats, std: int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Option& opt) const { - // clang-format off - // *INDENT-OFF* -#if NCNN_ARM82 - if (opt.use_fp16_storage && cpu_support_arm_asimdhp()) + if (bottom_blob.elembits() == 32) { - if (bottom_blob.elembits() == 32 && layer->support_fp16_storage) + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && layer->support_fp16_storage) { Mat bottom_blob_fp16; cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt); bottom_blob = bottom_blob_fp16; } - if (bottom_blob.elembits() == 16 && !layer->support_fp16_storage) - { - Mat bottom_blob_fp32; - cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - else + else #endif // NCNN_ARM82 #if NCNN_RVV - if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh()) - { - if (bottom_blob.elembits() == 32 && layer->support_fp16_storage) + if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh() && layer->support_fp16_storage) { Mat bottom_blob_fp16; cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt); bottom_blob = bottom_blob_fp16; } - if (bottom_blob.elembits() == 16 && !layer->support_fp16_storage) - { - Mat bottom_blob_fp32; - cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - else + else #endif // NCNN_RVV #if NCNN_BF16 - if (opt.use_bf16_storage) - { - if (bottom_blob.elembits() == 32 && layer->support_bf16_storage) + if (opt.use_bf16_storage && layer->support_bf16_storage) { Mat bottom_blob_bf16; cast_float32_to_bfloat16(bottom_blob, bottom_blob_bf16, opt); bottom_blob = bottom_blob_bf16; } - if (bottom_blob.elembits() == 16 && !layer->support_bf16_storage) - { - Mat bottom_blob_fp32; - cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - 
else #endif // NCNN_BF16 - { - // no type conversion + + // *INDENT-ON* + // clang-format on } - // *INDENT-ON* - // clang-format on int dst_elempack = 1; if (opt.use_packing_layout) @@ -746,6 +720,42 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio bottom_blob = bottom_blob_packed; } + if (bottom_blob.elembits() == 16) + { + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && !layer->support_fp16_storage) + { + Mat bottom_blob_fp32; + cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } + else +#endif // NCNN_ARM82 +#if NCNN_RVV + if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh() && !layer->support_fp16_storage) + { + Mat bottom_blob_fp32; + cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } + else +#endif // NCNN_RVV +#if NCNN_BF16 + if (opt.use_bf16_storage && !layer->support_bf16_storage) + { + Mat bottom_blob_fp32; + cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } +#endif // NCNN_BF16 + + // *INDENT-ON* + // clang-format on + } + return 0; } diff --git a/src/pipeline.cpp b/src/pipeline.cpp index efdaec80bde..8aed60e4803 100644 --- a/src/pipeline.cpp +++ b/src/pipeline.cpp @@ -19,8 +19,6 @@ #include "pipelinecache.h" #include "option.h" -#include - #if __ANDROID_API__ >= 26 #include #endif // __ANDROID_API__ >= 26 diff --git a/src/platform.h.in b/src/platform.h.in index 0ae8f708817..be1dd508388 100644 --- a/src/platform.h.in +++ b/src/platform.h.in @@ -20,6 +20,7 @@ #cmakedefine01 NCNN_SIMPLEOCV #cmakedefine01 NCNN_SIMPLEOMP #cmakedefine01 NCNN_SIMPLESTL +#cmakedefine01 NCNN_SIMPLEMATH #cmakedefine01 NCNN_THREADS #cmakedefine01 NCNN_BENCHMARK #cmakedefine01 NCNN_C_API @@ -245,6 +246,14 @@ private: #include #endif +// simplemath +#if NCNN_SIMPLEMATH +#include "simplemath.h" +#else +#include +#include +#endif + #endif // __cplusplus #if NCNN_STDIO diff --git a/src/simplemath.cpp b/src/simplemath.cpp new file mode 100644 index 00000000000..d48d23e3c20 --- /dev/null +++ b/src/simplemath.cpp @@ -0,0 +1,622 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
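+
+// simplemath.cpp provides freestanding, float-only replacements for the libm
+// routines ncnn uses (fabs/floor/round, sinf/cosf/tanf and their inverses,
+// expf/logf/powf, erf, fesetround/nearbyintf, ...), so a build configured with
+// NCNN_SIMPLEMATH=ON does not need to link the system math library. The
+// implementations favour simplicity over last-ulp accuracy; see the notes on
+// the individual routines below.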
+ +#include "platform.h" + +#if NCNN_SIMPLEMATH + +#include "simplemath.h" +#define __HI(X) *(1 + (short*)&x) +#define __LO(X) *(short*)&x +#define INFINITY (1.0 / 0) +#define FE_TONEAREST 0 +#define FE_DOWNWARD 1024 +#define FE_UPWARD 2048 +#define FE_TOWARDZERO 3072 + +/* +* ==================================================== +* some useful constants +* ==================================================== +*/ +static const float PI = 3.14159265358979323846; +static const float PI_2 = 1.57079632679489661923; /* PI/2 */ +static const float E = 2.71828182845904523536; + +/* re-interpret the bit pattern of a uint32 as an IEEE-754 float */ +static float uint32_as_float(uint32_t a) +{ + float r; + float* rp = &r; + uint32_t* ap = &a; + + *rp = *(float*)ap; + + return r; +} + +#ifdef __cplusplus +extern "C" { +#endif +/* +* ==================================================== +* Discontinuous function +* ==================================================== +*/ +float fabs(float x) +{ + return x > 0 ? x : -x; +} + +float fabsf(float x) +{ + return fabs(x); +} + +float fmod(float numer, float denom) +{ + if (denom == 0.0) + { + return numer; + } + if (numer <= denom) + { + return numer; + } + + int quotient = static_cast(numer / denom); + return numer - quotient * denom; +} + +float floor(float x) +{ + int intValue = static_cast(x); + if (x < 0 && x != intValue) + { + intValue -= 1; + } + return intValue; +} + +float floorf(float x) +{ + return floor(x); +} + +float round(float x) +{ + float ret = x > 0 ? floor(x + 0.5) : ceil(x - 0.5); + return ret; +} + +float roundf(float x) +{ + return round(x); +} + +float ceilf(float x) +{ + return ceil(x); +} + +float ceil(float x) +{ + int intValue = static_cast(x); + if (x == intValue) + { + return x; + } + return floor(x + 1); +} + +float fmaxf(float x, float y) +{ + return x > y ? 
x : y; +} + +float truncf(float x) +{ + int intValue = static_cast(x); + return static_cast(intValue); +} + +float frac(float x) +{ + return x - floor(x); +} + +/* +* ==================================================== +* trigonometric functions +* ==================================================== +*/ + +/* + modify from https://developer.download.nvidia.cn/cg/sin.html +*/ +float sinf(float a) +{ + const int x = 0; + const int y = 1; + const int z = 2; + const int w = 3; + + float c0[4] = {0.0, 0.5, 1.0, 0.0}; + float c1[4] = {0.25, -9.0, 0.75, 0.159154943091}; + float c2[4] = {24.9808039603, -24.9808039603, -60.1458091736, 60.1458091736}; + float c3[4] = {85.4537887573, -85.4537887573, -64.9393539429, 64.9393539429}; + float c4[4] = {19.7392082214, -19.7392082214, -1.0, 1.0}; + float r0[3], r1[3], r2[3]; + + // r1.x = c1.w * a - c1.x + r1[x] = c1[w] * a - c1[x]; + // r1.y = frac( r1.x ); + r1[y] = frac(r1[x]); + // r2.x = (float) ( r1.y < c1.x ); + r2[x] = (float)(r1[y] < c1[x]); + // r2.yz = (float2) ( r1.yy >= c1.yz ); + r2[y] = (float)(r1[y] >= c1[y]); + r2[z] = (float)(r1[y] >= c1[z]); + // r2.y = dot( r2, c4.zwz ); + r2[y] = r2[x] * c4[z] + r2[y] * c4[w] + r2[z] * c4[z]; + + // r0 = c0.xyz - r1.yyy + r0[x] = c0[x] - r1[y]; + r0[y] = c0[y] - r1[y]; + r0[z] = c0[z] - r1[y]; + + // r0 = r0 * r0 + r0[x] = r0[x] * r0[x]; + r0[y] = r0[y] * r0[y]; + r0[z] = r0[z] * r0[z]; + + // r1 = c2.xyx * r0 + c2.zwz + r1[x] = c2[x] * r0[x] + c2[z]; + r1[y] = c2[y] * r0[y] + c2[w]; + r1[z] = c2[x] * r0[z] + c2[z]; + + // r1 = r1 * r0 + c3.xyx + r1[x] = r1[x] * r0[x] + c3[x]; + r1[y] = r1[y] * r0[y] + c3[y]; + r1[z] = r1[z] * r0[z] + c3[x]; + + // r1 = r1 * r0 + c3.zwz + r1[x] = r1[x] * r0[x] + c3[z]; + r1[y] = r1[y] * r0[y] + c3[w]; + r1[z] = r1[z] * r0[z] + c3[z]; + + // r1 = r1 * r0 + c4.xyx + r1[x] = r1[x] * r0[x] + c4[x]; + r1[y] = r1[y] * r0[y] + c4[y]; + r1[z] = r1[z] * r0[z] + c4[x]; + + // r1 = r1 * r0 + c4.zwz + r1[x] = r1[x] * r0[x] + c4[z]; + r1[y] = r1[y] * r0[y] + c4[w]; + r1[z] = r1[z] * r0[z] + c4[z]; + + //r0.x = dot(r1, -r2) + r0[x] = -(r1[x] * r2[x] + r1[y] * r2[y] + r1[z] * r2[z]); + + return r0[x]; +} + +float cosf(float x) +{ + return sinf(PI_2 + x); +} + +float tanf(float x) +{ + return sinf(x) / cosf(x); +} + +/* copy from https://developer.download.nvidia.cn/cg/asin.html */ +float asinf(float x) +{ + float negate = float(x < 0); + x = fabs(x); + float ret = -0.0187293; + ret *= x; + ret += 0.0742610; + ret *= x; + ret -= 0.2121144; + ret *= x; + ret += 1.5707288; + ret = PI * 0.5 - sqrt(1.0 - x) * ret; + return ret - 2 * negate * ret; +} + +/* copy from https://developer.download.nvidia.cn/cg/acos.html */ +float acosf(float x) +{ + float negate = float(x < 0); + x = fabs(x); + float ret = -0.0187293; + ret = ret * x; + ret = ret + 0.0742610; + ret = ret * x; + ret = ret - 0.2121144; + ret = ret * x; + ret = ret + 1.5707288; + ret = ret * sqrt(1.0 - x); + ret = ret - 2 * negate * ret; + return negate * PI + ret; +} + +/* copy from https://developer.download.nvidia.cn/cg/atan.html */ +float atanf(float a) +{ + if (a < 0) + { + return -atanf(-a); + } + if (a > 1) + { + return PI_2 - atanf(1 / a); + } + float s = a * a; + float r = 0.0027856871020048857; + + r = r * s - 0.015866000205278397; + r = r * s + 0.042472220957279205; + r = r * s - 0.07497530430555344f; + r = r * s + 0.10644879937171936; + r = r * s - 0.14207030832767487; + r = r * s + 0.19993454217910767f; + r = r * s - 0.33333146572113037f; + r = r * s; + return r * a + a; +} + +float atan2f(float y, float x) +{ + if 
(x == 0 && y == 0) + { + // error + return 0; + } + if (y == 0) + { + return x > 0 ? 0 : PI; + } + if (x == 0) + { + return copysignf(PI_2, y); + } + + if (x > 0 && y > 0) + { + return atanf(y / x); + } + else if (x < 0 && y > 0) + { + return PI - atanf(y / -x); + } + else if (x > 0 && y < 0) + { + return -atanf(-y / x); + } + else + { + return -PI + atanf(-y / -x); + } +} + +float tanhf(float v) +{ + if (v >= 8 || v <= -8) + { + return copysignf(1, v); + } + float exp2v = expf(2 * v); + return (exp2v - 1) / (exp2v + 1); +} + +/* +* ==================================================== +* power functions +* ==================================================== +*/ + +float sqrtf(float x) +{ + return powf(x, 0.5); +} + +float sqrt(float x) +{ + return sqrtf(x); +} + +float powf(float x, float y) +{ + return expf(y * logf(x)); +} + +/* +* ==================================================== +* exponential and logarithm functions +* ==================================================== +*/ + +/* copy and modify from https://zhuanlan.zhihu.com/p/541466411 */ +float logf(float x) +{ + static const float + ln2_hi + = 6.93147180369123816490e-01, /* 3fe62e42 fee00000 */ + ln2_lo = 1.90821492927058770002e-10, /* 3dea39ef 35793c76 */ + two25 = 3.3554432e+07, + Lg1 = 6.666666666666735130e-01, /* 3FE55555 55555593 */ + Lg2 = 3.999999999940941908e-01, /* 3FD99999 9997FA04 */ + Lg3 = 2.857142874366239149e-01, /* 3FD24924 94229359 */ + Lg4 = 2.222219843214978396e-01, /* 3FCC71C5 1D8E78AF */ + Lg5 = 1.818357216161805012e-01, /* 3FC74664 96CB03DE */ + Lg6 = 1.531383769920937332e-01, /* 3FC39A09 D078C69F */ + Lg7 = 1.479819860511658591e-01; /* 3FC2F112 DF3E5244 */ + + static float zero = 0.0; + float f, s, z, R, w, t1, t2, dk; + short k, hx, i; + unsigned short lx; + + hx = __HI(x); /* high word of x */ + lx = __LO(x); /* low word of x */ + + k = 0; + if (hx < 0x0080) + { /* x < 2**-126 */ + if (((hx & 0x7fff) | lx) == 0) + return -two25 / zero; /* log(+-0)=-inf */ + if (hx < 0) return (x - x) / zero; /* log(-#) = NaN */ + k -= 25; + x *= two25; /* subnormal number, scale up x */ + hx = __HI(x); /* high word of x */ + } + + if (hx >= 0x7f80) return x + x; + k += (hx >> 7) - 127; + hx &= 0x007f; + i = (hx + 0x4b) & 0x0080; + __HI(x) = hx | (i ^ 0x3f80); /* normalize x or x/2 */ + k += (i >> 7); + f = x - 1.0f; + + s = f / (2.0f + f); + dk = (float)k; + z = s * s; + w = z * z; + t1 = w * (Lg2 + w * (Lg4 + w * Lg6)); + t2 = z * (Lg1 + w * (Lg3 + w * (Lg5 + w * Lg7))); + R = t2 + t1; + if (k == 0) + return f - s * (f - R); + else + return dk * ln2_hi - ((s * (f - R) - dk * ln2_lo) - f); +} + +/* copy from https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff */ +float expf(float a) +{ + if (a < 0) + { + float tmp = expf(-a); + + float ret = 1 / tmp; + + return ret; + } + float f, r, j; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = 1.442695f * a; + j = round(j) + 12582912.f; // There is a bug, and the program lives on it. 
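+    // 12582912.0f is 1.5 * 2^23: adding it to a float of small magnitude pushes
+    // the fraction bits out of the mantissa, so the subtraction on the next line
+    // leaves j rounded to the nearest integer. Because round(j) above already
+    // produced an integer, the add/sub pair is redundant here; it is what lets
+    // the commented-out fmaf() variant below work without an explicit round().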
+ j = j - 12582912.f; + // j = fmaf(1.442695f, a, 12582912.f) - 12582912.f; // 0x1.715476p0, 0x1.8p23 + f = fmaf(j, -6.93145752e-1f, a); // -0x1.62e400p-1 // log_2_hi + f = fmaf(j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 // log_2_lo + i = (int)j; + // approximate r = exp(f) on interval [-log(2)/2, +log(2)/2] + r = 1.37805939e-3f; // 0x1.694000p-10 + r = fmaf(r, f, 8.37312452e-3f); // 0x1.125edcp-7 + r = fmaf(r, f, 4.16695364e-2f); // 0x1.555b5ap-5 + r = fmaf(r, f, 1.66664720e-1f); // 0x1.555450p-3 + r = fmaf(r, f, 4.99999851e-1f); // 0x1.fffff6p-2 + r = fmaf(r, f, 1.00000000e+0f); // 0x1.000000p+0 + r = fmaf(r, f, 1.00000000e+0f); // 0x1.000000p+0 + + float s, t; + uint32_t ia; + // exp(a) = 2**i * r + ia = (i > 0) ? 0 : 0x83000000u; + s = uint32_as_float(0x7f000000u + ia); + t = uint32_as_float(((uint32_t)i << 23) - ia); + r = r * s; + r = r * t; + + // handle special cases: severe overflow / underflow + if (fabsf(a) >= 104.0f) r = (a > 0) ? INFINITY : 0.0f; + + return r; +} + +float frexp(float x, int* y) +{ + int hx, k; + hx = __HI(x); + k = (hx >> 7) & 0x00ff; + k = k - 127; + __HI(x) = hx & 0x807f; + __HI(x) = __HI(x) | 0x3f80; + + *y = k + 1; // y in [1/2, 1) + return x / 2; +} + +float log(float x) +{ + return logf(x); +} + +float log10f(float x) +{ + static const float ln10 = 2.3025850929940456840179914546844; + return logf(x) / ln10; +} + +/* +* ==================================================== +* probability functions +* ==================================================== +*/ + +/* copy from https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff */ +float erf(float a) +{ + float r, s, t, u; + + t = fabsf(a); + s = a * a; + if (t > 0.927734375f) + { // 475/512 + // maximum error 0.99527 ulp + r = fmaf(-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = fmaf(-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = fmaf(r, s, u); + r = fmaf(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = fmaf(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = fmaf(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = fmaf(r, t, -t); + r = 1.0f - expf(r); + r = copysignf(r, a); + } + else + { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = fmaf(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = fmaf(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = fmaf(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = fmaf(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = fmaf(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = fmaf(r, a, a); + } + return r; +} + +float erfcf(float x) +{ + return 1.0 - erf(x); +} + +/* +* ==================================================== +* other functions +* ==================================================== +*/ + +int msb(unsigned int v) +{ + static const int pos[32] = {0, 1, 28, 2, 29, 14, 24, 3, + 30, 22, 20, 15, 25, 17, 4, 8, 31, 27, 13, 23, 21, 19, + 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 + }; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v = (v >> 1) + 1; + return pos[(v * 0x077CB531UL) >> 27]; +} + +float fmaf(float x, float y, float z) +{ + float tmp = x * y; + float ret = tmp + z; + return ret; +} + +float copysignf(float x, float y) +{ + return fabsf(x) * (y > 0 ? 
1 : -1); +} + +int round_mode = 0; +void fesetround(int mode) +{ + round_mode = mode; +} + +int fegetround() +{ + return round_mode; +} + +float nearbyintf(float x) +{ + int intPart = static_cast(x); + float floatPart = fabs(x - intPart); + if (floatPart == 0) + { + return x; + } + + if (x > 0) + { + if (round_mode == FE_DOWNWARD || round_mode == FE_TOWARDZERO) + { + return static_cast(intPart); + } + if (round_mode == FE_UPWARD) + { + return static_cast(intPart) + 1.0; + } + if (round_mode == FE_TONEAREST) + { + if (floatPart == 0.5) + { + return intPart % 2 == 0 ? static_cast(intPart) : static_cast(intPart) + 1; + } + return round(x); + } + } + if (x < 0) + { + if (round_mode == FE_UPWARD || round_mode == FE_TOWARDZERO) + { + return static_cast(intPart); + } + if (round_mode == FE_DOWNWARD) + { + return static_cast(intPart) - 1.0; + } + if (round_mode == FE_TONEAREST) + { + if (floatPart == 0.5) + { + return intPart % 2 == 0 ? static_cast(intPart) : static_cast(intPart) - 1; + } + return round(x); + } + } +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // NCNN_SIMPLEMATH diff --git a/src/simplemath.h b/src/simplemath.h new file mode 100644 index 00000000000..fd7fa6964eb --- /dev/null +++ b/src/simplemath.h @@ -0,0 +1,102 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#ifndef NCNN_SIMPLEMATH_H +#define NCNN_SIMPLEMATH_H + +#include "platform.h" + +#if NCNN_SIMPLEMATH + +#ifdef __cplusplus +extern "C" { +#endif +/* +* ==================================================== +* discrete functions +* ==================================================== +*/ +NCNN_EXPORT float fabs(float); +NCNN_EXPORT float fabsf(float); +NCNN_EXPORT float fmod(float, float); +NCNN_EXPORT float floor(float); +NCNN_EXPORT float floorf(float); +NCNN_EXPORT float round(float); +NCNN_EXPORT float roundf(float); +NCNN_EXPORT float ceil(float); +NCNN_EXPORT float ceilf(float); +NCNN_EXPORT float fmaxf(float, float); +NCNN_EXPORT float truncf(float); +NCNN_EXPORT float frac(float); +/* +* ==================================================== +* trigonometric functions +* ==================================================== +*/ +NCNN_EXPORT float sinf(float); +NCNN_EXPORT float cosf(float); +NCNN_EXPORT float tanf(float); +NCNN_EXPORT float asinf(float); +NCNN_EXPORT float acosf(float); +NCNN_EXPORT float atanf(float); +NCNN_EXPORT float atan2f(float, float); +NCNN_EXPORT float tanhf(float); + +/* +* ==================================================== +* power functions +* ==================================================== +*/ +NCNN_EXPORT float sqrtf(float); +NCNN_EXPORT float sqrt(float); +NCNN_EXPORT float powf(float, float); + +/* +* ==================================================== +* exponential and logarithm functions +* ==================================================== +*/ +NCNN_EXPORT float expf(float); +NCNN_EXPORT float frexp(float, int*); +NCNN_EXPORT float logf(float); +NCNN_EXPORT float log(float); +NCNN_EXPORT float log10f(float); + +/* +* ==================================================== +* probability functions +* ==================================================== +*/ +NCNN_EXPORT float erf(float); +NCNN_EXPORT float erfcf(float); + +/* +* ==================================================== +* other functions +* ==================================================== +*/ +NCNN_EXPORT int msb(unsigned int); +NCNN_EXPORT float fmaf(float, float, float); +NCNN_EXPORT float copysignf(float, float); +NCNN_EXPORT void fesetround(int); +NCNN_EXPORT int fegetround(); +NCNN_EXPORT float nearbyintf(float); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // NCNN_SIMPLEMATH + +#endif // NCNN_SIMPLEMATH_H \ No newline at end of file diff --git a/src/stb_image.h b/src/stb_image.h index 8d9fc9c581f..1b4b337328e 100644 --- a/src/stb_image.h +++ b/src/stb_image.h @@ -589,7 +589,7 @@ STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const ch #include #if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) -#include // ldexp, pow + // ldexp, pow #endif #ifndef STBI_NO_STDIO diff --git a/src/stb_image_write.h b/src/stb_image_write.h index e4b32ed1bc3..aa397c09d53 100644 --- a/src/stb_image_write.h +++ b/src/stb_image_write.h @@ -214,7 +214,7 @@ STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean); #include #include #include -#include + #if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED)) // ok diff --git a/tests/test_clip.cpp b/tests/test_clip.cpp index 02ad73352cc..553085e2d63 100644 --- a/tests/test_clip.cpp +++ b/tests/test_clip.cpp @@ -36,10 +36,7 @@ static int test_clip_0() { return 0 || test_clip(RandomMat(5, 6, 7, 24), -1.f, 1.f) - || test_clip(RandomMat(5, 6, 7, 24), -1.f, 1.f) - || test_clip(RandomMat(7, 8, 9, 12), -1.f, 1.f) || test_clip(RandomMat(7, 8, 9, 12), -1.f, 1.f) - 
|| test_clip(RandomMat(3, 4, 5, 13), -1.f, 1.f) || test_clip(RandomMat(3, 4, 5, 13), -1.f, 1.f); } diff --git a/tests/test_convolution1d.cpp b/tests/test_convolution1d.cpp index c8dd55ffe6a..bea75da301c 100644 --- a/tests/test_convolution1d.cpp +++ b/tests/test_convolution1d.cpp @@ -77,7 +77,7 @@ static int test_convolution1d_0() const int s = kdsp[i][2]; const int p = kdsp[i][3]; const int b0 = i % 2; - const int b1 = 1 - b1; + const int b1 = 1 - b0; int ret = 0 || test_convolution1d(9, 1, 1, k, d, s, p, b0) diff --git a/tests/test_convolution_3.cpp b/tests/test_convolution_3.cpp index fa358d0670c..1d0f8f079b6 100644 --- a/tests/test_convolution_3.cpp +++ b/tests/test_convolution_3.cpp @@ -190,6 +190,30 @@ static int test_convolution_int8(int w, int h, int c, int outch, int kernel, int return ret; } + if (kernel == 3 && dilation == 1 && stride == 1) + { + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_storage = false; + opt.use_shader_pack8 = false; + opt.use_image_storage = false; + opt.use_sgemm_convolution = false; + opt.use_winograd_convolution = true; + opt.use_winograd23_convolution = true; + opt.use_winograd43_convolution = false; + + ret = test_layer_opt("Convolution", pd, weights, opt, a, requant ? 1.0f : 0.001f, 0, flag); + if (ret != 0) + { + fprintf(stderr, "test_convolution_int8 failed w=%d h=%d c=%d outch=%d kernel=%d dilation=%d stride=%d pad=%d bias=%d requant=%d act=%d actparams=[%f,%f]\n", w, h, c, outch, kernel, dilation, stride, pad, bias, requant, activation_type, activation_params[0], activation_params[1]); + return ret; + } + } + { ncnn::Option opt; opt.num_threads = 1; @@ -310,6 +334,7 @@ static int test_convolution_1() || test_convolution_int8(4, 20, 16, 24, 3, 1, 1, 1, 0) || test_convolution_int8(6, 7, 64, 64, 3, 1, 2, 0, 1) || test_convolution_int8(25, 33, 16, 15, 3, 1, 1, 1, 0) + || test_convolution_int8(25, 33, 31, 31, 3, 1, 1, 1, 0) || test_convolution_int8(7, 7, 15, 12, 3, 1, 1, 1, 0) || test_convolution_int8(5, 6, 31, 9, 5, 1, 1, 0, 1) || test_convolution_int8(5, 7, 32, 8, 5, 1, 2, 0, 1) diff --git a/tests/test_elu.cpp b/tests/test_elu.cpp index 4d170dc2e29..a8736a3efad 100644 --- a/tests/test_elu.cpp +++ b/tests/test_elu.cpp @@ -26,13 +26,22 @@ static int test_elu(const ncnn::Mat& a) int ret = test_layer("ELU", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_elu failed alpha=%f\n", alpha); + fprintf(stderr, "test_elu failed a.dims=%d a=(%d %d %d %d) alpha=%f\n", a.dims, a.w, a.h, a.d, a.c, alpha); } return ret; } static int test_elu_0() +{ + return 0 + || test_elu(RandomMat(7, 6, 5, 32)) + || test_elu(RandomMat(5, 6, 7, 24)) + || test_elu(RandomMat(7, 8, 9, 12)) + || test_elu(RandomMat(3, 4, 5, 13)); +} + +static int test_elu_1() { return 0 || test_elu(RandomMat(4, 7, 32)) @@ -41,7 +50,7 @@ static int test_elu_0() || test_elu(RandomMat(3, 5, 13)); } -static int test_elu_1() +static int test_elu_2() { return 0 || test_elu(RandomMat(13, 32)) @@ -50,7 +59,7 @@ static int test_elu_1() || test_elu(RandomMat(19, 15)); } -static int test_elu_2() +static int test_elu_3() { return 0 || test_elu(RandomMat(128)) @@ -66,5 +75,6 @@ int main() return 0 || test_elu_0() || test_elu_1() - || test_elu_2(); + || test_elu_2() + || test_elu_3(); } diff --git a/tests/test_gridsample.cpp b/tests/test_gridsample.cpp index 70c96b30480..0e384115352 100644 --- a/tests/test_gridsample.cpp +++ b/tests/test_gridsample.cpp @@ -1,6 +1,6 @@ 
// Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -15,12 +15,13 @@ #include "layer/gridsample.h" #include "testutil.h" -static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample_type, int padding_mode, int align_corner) +static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample_type, int padding_mode, int align_corner, int permute_fusion) { ncnn::ParamDict pd; pd.set(0, sample_type); pd.set(1, padding_mode); pd.set(2, align_corner); + pd.set(3, permute_fusion); std::vector weights(0); @@ -31,9 +32,9 @@ static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample int ret = test_layer("GridSample", pd, weights, as); if (ret != 0) { - fprintf(stderr, "test_gridsample failed a.dims=%d a=(%d %d %d %d) grid.dims=%d grid=(%d %d %d %d) sample_type=%d padding_mode=%d align_corner=%d", + fprintf(stderr, "test_gridsample failed a.dims=%d a=(%d %d %d %d) grid.dims=%d grid=(%d %d %d %d) sample_type=%d padding_mode=%d align_corner=%d permute_fusion=%d", a.dims, a.w, a.h, a.d, a.c, grid.dims, grid.w, grid.h, grid.d, grid.c, - sample_type, padding_mode, align_corner); + sample_type, padding_mode, align_corner, permute_fusion); } return ret; @@ -42,81 +43,141 @@ static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample static int test_gridsample_0() { return 0 - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 3, 1); + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), 
RandomMat(2, 11, 13), 2, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 3, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 3, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 3, 1, 1); } static int test_gridsample_1() { return 0 - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), 
RandomMat(2, 27, 21), 3, 3, 1); + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 3, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 3, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 3, 1, 1); } static int test_gridsample_2() { return 0 - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), 
RandomMat(3, 27, 21, 10), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 3, 1); + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 1, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 1, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 2, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 2, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 3, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 3, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 1, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 1, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 2, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 2, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 3, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 3, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 1, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 1, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 2, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 2, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 3, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 3, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 1, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 1, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 2, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 2, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 3, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 3, 1, 1); } static int test_gridsample_3() { return 0 - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 3, 1); + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 1, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 1, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), 
RandomMat(3, 11, 12, 16), 1, 2, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 2, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 3, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 3, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 1, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 1, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 2, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 2, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 3, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 3, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 1, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 1, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 2, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 2, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 3, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 3, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 1, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 1, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 2, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 2, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 3, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 3, 1, 1); } int main() diff --git a/tests/test_mat_pixel_affine.cpp b/tests/test_mat_pixel_affine.cpp index 817b0f57a3c..94ea366f9e7 100644 --- a/tests/test_mat_pixel_affine.cpp +++ b/tests/test_mat_pixel_affine.cpp @@ -15,7 +15,6 @@ #include "mat.h" #include "prng.h" -#include #include static struct prng_rand_t g_prng_rand_state; diff --git a/tests/test_mat_pixel_resize.cpp b/tests/test_mat_pixel_resize.cpp index 725c30e0bdf..38b8c5ab356 100644 --- a/tests/test_mat_pixel_resize.cpp +++ b/tests/test_mat_pixel_resize.cpp @@ -15,7 +15,6 @@ #include "mat.h" #include "prng.h" -#include #include static struct prng_rand_t g_prng_rand_state; diff --git a/tests/test_selu.cpp b/tests/test_selu.cpp index 83ce4f5de25..3844c94ccf1 100644 --- a/tests/test_selu.cpp +++ b/tests/test_selu.cpp @@ -26,7 +26,7 @@ static int test_selu(const ncnn::Mat& a, float alpha, float lambda) int ret = test_layer("SELU", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_selu failed a.dims=%d a=(%d %d %d) alpha=%f lambda=%f\n", a.dims, a.w, a.h, a.c, alpha, lambda); + fprintf(stderr, "test_selu failed a.dims=%d a=(%d %d %d %d) alpha=%f lambda=%f\n", a.dims, a.w, a.h, a.d, a.c, alpha, lambda); } return ret; @@ -35,25 +35,37 @@ static int test_selu(const ncnn::Mat& a, float alpha, float lambda) static int test_selu_0() { return 0 + || test_selu(RandomMat(7, 6, 5, 32), 1.673264f, 1.050700f) + || test_selu(RandomMat(5, 6, 7, 24), 1.673264f, 1.050700f) + || test_selu(RandomMat(7, 8, 9, 12), 1.673264f, 1.050700f) + || test_selu(RandomMat(3, 4, 5, 13), 1.673264f, 1.050700f); +} + +static int test_selu_1() +{ + return 0 + || test_selu(RandomMat(4, 7, 32), 1.673264f, 1.050700f) || test_selu(RandomMat(5, 7, 24), 1.673264f, 
1.050700f) || test_selu(RandomMat(7, 9, 12), 1.673264f, 1.050700f) || test_selu(RandomMat(3, 5, 13), 1.673264f, 1.050700f); } -static int test_selu_1() +static int test_selu_2() { return 0 + || test_selu(RandomMat(13, 32), 1.673264f, 1.050700f) || test_selu(RandomMat(15, 24), 1.673264f, 1.050700f) || test_selu(RandomMat(17, 12), 1.673264f, 1.050700f) || test_selu(RandomMat(19, 15), 1.673264f, 1.050700f); } -static int test_selu_2() +static int test_selu_3() { return 0 || test_selu(RandomMat(128), 1.673264f, 1.050700f) || test_selu(RandomMat(124), 1.673264f, 1.050700f) - || test_selu(RandomMat(127), 1.673264f, 1.050700f); + || test_selu(RandomMat(127), 1.673264f, 1.050700f) + || test_selu(RandomMat(120), 1.673264f, 1.050700f); } int main() @@ -63,5 +75,6 @@ int main() return 0 || test_selu_0() || test_selu_1() - || test_selu_2(); + || test_selu_2() + || test_selu_3(); } diff --git a/tests/test_shufflechannel.cpp b/tests/test_shufflechannel.cpp index f2f344b958e..ad21a184e89 100644 --- a/tests/test_shufflechannel.cpp +++ b/tests/test_shufflechannel.cpp @@ -53,7 +53,10 @@ static int test_shufflechannel_0() || test_shufflechannel(5, 7, 24, 2, 0) || test_shufflechannel(3, 7, 24, 3, 0) || test_shufflechannel(5, 9, 24, 4, 0) + || test_shufflechannel(3, 7, 32, 2, 0) || test_shufflechannel(3, 7, 32, 8, 0) + || test_shufflechannel(5, 7, 48, 2, 0) + || test_shufflechannel(5, 7, 48, 3, 0) || test_shufflechannel(5, 9, 64, 4, 0); } @@ -76,7 +79,10 @@ static int test_shufflechannel_1() || test_shufflechannel(5, 7, 24, 2, 1) || test_shufflechannel(3, 7, 24, 3, 1) || test_shufflechannel(5, 9, 24, 4, 1) + || test_shufflechannel(3, 7, 32, 2, 1) || test_shufflechannel(3, 7, 32, 8, 1) + || test_shufflechannel(5, 7, 48, 2, 1) + || test_shufflechannel(5, 7, 48, 3, 1) || test_shufflechannel(3, 7, 64, 4, 1); } diff --git a/tests/testutil.h b/tests/testutil.h index b879fa527fb..0794bdd463d 100644 --- a/tests/testutil.h +++ b/tests/testutil.h @@ -20,7 +20,6 @@ #include "mat.h" #include "prng.h" -#include #include #include diff --git a/toolchains/aarch64-linux-gnu-c.toolchain.cmake b/toolchains/aarch64-linux-gnu-c.toolchain.cmake index 07b39de87b6..cde92c07070 100644 --- a/toolchains/aarch64-linux-gnu-c.toolchain.cmake +++ b/toolchains/aarch64-linux-gnu-c.toolchain.cmake @@ -11,7 +11,7 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) set(CMAKE_C_FLAGS "-march=armv8-a") set(CMAKE_CXX_FLAGS "-march=armv8-a") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -nodefaultlibs -fno-builtin -fno-stack-protector -nostdinc++ -lc") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -nodefaultlibs -fno-builtin -fno-stack-protector -nostdinc++ -mno-outline-atomics -lc") # cache flags set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 3d09ec1859d..fd5105e612f 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1734,6 +1734,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 0=%d", sample_type) fprintf_param_value(" 1=%d", padding_mode) fprintf_param_value(" 2=%d", align_corner) + fprintf_param_value(" 3=%d", permute_fusion) } else if (layer->type == "GroupNorm") { diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index 3a08cbc249e..0c8326fc942 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -70,6 +70,11 @@ if(Torch_VERSION VERSION_LESS "1.8") message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8") endif() +if(Torch_VERSION VERSION_GREATER_EQUAL "2.1") + # c++17 is required for using 
torch 2.1+ headers + set(CMAKE_CXX_STANDARD 17) +endif() + if(TorchVision_FOUND) message(STATUS "Building with TorchVision") add_definitions(-DPNNX_TORCHVISION) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 4819cba617a..58264dfd975 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -154,6 +154,7 @@ set(pnnx_pass_level2_SRCS pass_level2/F_mish.cpp pass_level2/F_normalize.cpp pass_level2/F_pad.cpp + pass_level2/F_pairwise_distance.cpp pass_level2/F_pixel_shuffle.cpp pass_level2/F_pixel_unshuffle.cpp pass_level2/F_prelu.cpp @@ -269,6 +270,8 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_unbind.cpp pass_level2/torch_unsqueeze.cpp pass_level2/torch_var.cpp + pass_level2/torch_view_as_complex.cpp + pass_level2/torch_view_as_real.cpp pass_level2/torch_zeros.cpp pass_level2/torch_zeros_like.cpp pass_level2/torch_stft.cpp @@ -342,8 +345,10 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_constant_expression.cpp pass_level5/fuse_conv1d_batchnorm1d.cpp pass_level5/fuse_conv2d_batchnorm2d.cpp + pass_level5/fuse_conv3d_batchnorm3d.cpp pass_level5/fuse_convtranspose1d_batchnorm1d.cpp pass_level5/fuse_convtranspose2d_batchnorm2d.cpp + pass_level5/fuse_convtranspose3d_batchnorm3d.cpp pass_level5/fuse_contiguous_view.cpp pass_level5/fuse_linear_batchnorm1d.cpp pass_level5/fuse_pad_conv1d.cpp diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index dc8ca72dc7e..134b64839d2 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -47,10 +47,24 @@ static std::string get_basename(const std::string& path) { - std::string base = path.substr(0, path.find_last_of('.')); + std::string dirpath; + std::string filename; + + size_t dirpos = path.find_last_of("/\\"); + if (dirpos != std::string::npos) + { + dirpath = path.substr(0, dirpos + 1); + filename = path.substr(dirpos + 1); + } + else + { + filename = path; + } + + std::string base = filename.substr(0, path.find_last_of('.')); // sanitize - std::replace(base.begin(), base.end(), '-', '_'); - return base; + return dirpath + base; } static void parse_string_list(char* s, std::vector& list) @@ -300,6 +314,11 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); } +#ifdef PNNX_TORCHVISION + // call some vision api to register vision ops :P + (void)vision::cuda_version(); +#endif + for (auto m : customop_modules) { fprintf(stderr, "load custom module %s\n", m.c_str()); diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp index e922e458a06..fb73a497c5f 100644 --- a/tools/pnnx/src/pass_level2/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -45,4 +45,38 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm, 10) +class F_batch_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 10 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 running_mean +pnnx.Input input_2 0 1 running_var +pnnx.Input input_3 0 1 weight +pnnx.Input input_4 0 1 bias +prim::Constant op_0 0 1 momentum value=* +prim::Constant op_1 0 1 eps value=%eps +aten::_native_batch_norm_legit_no_training op_2 7 3 input weight bias running_mean running_var momentum eps out save_mean save_invstd +pnnx.Output output 3 0 out save_mean save_invstd +)PNNXIR"; + } + + const char* type_str() const + { + return "F.batch_norm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + 
GraphRewriterPass::write(op, captured_params, captured_attrs); + + op->outputs.resize(1); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp b/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp new file mode 100644 index 00000000000..8177b25d52f --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_pairwise_distance : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 x1 +pnnx.Input input_1 0 1 x2 +prim::Constant op_0 0 1 p value=%p +prim::Constant op_1 0 1 eps value=%eps +prim::Constant op_2 0 1 keepdim value=%keepdim +aten::pairwise_distance op_3 5 1 x1 x2 p eps keepdim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pairwise_distance"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pairwise_distance, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp index e7ca7bbf824..8dcfafaf12b 100644 --- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp @@ -42,4 +42,31 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10) +class F_scaled_dot_product_attention_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input input_3 0 1 attn_mask +prim::Constant op_0 0 1 dropout_p value=%dropout_p +prim::Constant op_1 0 1 is_causal value=%is_causal +prim::Constant op_2 0 1 scale value=%scale +aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.scaled_dot_product_attention"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp new file mode 100644 index 00000000000..e00ff1371ca --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_complex : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_complex op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_complex"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_complex, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/src/pass_level2/torch_view_as_real.cpp b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp new file mode 100644 index 00000000000..83327e01ef9 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
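+
+// As with torch_view_as_complex above, this pass matches a single aten op via
+// its textual PNNX graph pattern: the first line of the pattern is the PNNX
+// magic number, the second is "operator-count operand-count", and every
+// following line lists an operator as type, name, input count, output count,
+// the operand names, and any key=value parameters.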
+ +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_real : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_real op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_real"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_real, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 5d90c9554fa..43b720fc1e7 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -34,8 +34,10 @@ #include "pass_level5/fuse_constant_expression.h" #include "pass_level5/fuse_conv1d_batchnorm1d.h" #include "pass_level5/fuse_conv2d_batchnorm2d.h" +#include "pass_level5/fuse_conv3d_batchnorm3d.h" #include "pass_level5/fuse_convtranspose1d_batchnorm1d.h" #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" +#include "pass_level5/fuse_convtranspose3d_batchnorm3d.h" #include "pass_level5/fuse_contiguous_view.h" #include "pass_level5/fuse_layernorm.h" #include "pass_level5/fuse_linear_batchnorm1d.h" @@ -101,8 +103,10 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_conv1d_batchnorm1d(g); fuse_conv2d_batchnorm2d(g); + fuse_conv3d_batchnorm3d(g); fuse_convtranspose1d_batchnorm1d(g); fuse_convtranspose2d_batchnorm2d(g); + fuse_convtranspose3d_batchnorm3d(g); fuse_linear_batchnorm1d(g); fuse_pad_conv1d(g); diff --git a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp new file mode 100644 index 00000000000..ea89e99cb53 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp @@ -0,0 +1,138 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+
+#include "fuse_conv3d_batchnorm3d.h"
+
+#include "pass_level2.h"
+
+#include <math.h>
+
+#include <string.h>
+
+namespace pnnx {
+
+class fuse_conv3d_batchnorm3d_pass : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+nn.Conv3d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+nn.BatchNorm3d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.Conv3d";
+    }
+
+    const char* name_str() const
+    {
+        return "convbn3d";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        op->params["in_channels"] = captured_params.at("in_channels");
+        op->params["out_channels"] = captured_params.at("out_channels");
+        op->params["kernel_size"] = captured_params.at("kernel_size");
+        op->params["padding_mode"] = captured_params.at("padding_mode");
+        op->params["stride"] = captured_params.at("stride");
+        op->params["padding"] = captured_params.at("padding");
+        op->params["dilation"] = captured_params.at("dilation");
+        op->params["groups"] = captured_params.at("groups");
+        op->params["bias"] = true;
+
+        // resolve merged conv3d weight and bias
+        int channels = captured_params.at("num_features").i;
+        float bn_eps = captured_params.at("eps").f;
+        bool has_bn_affine = captured_params.at("affine").b;
+        bool has_conv_bias = captured_params.at("bias").b;
+
+        auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
+        auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
+        auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
+        auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
+
+        // a = bias - slope * mean / sqrt(var + eps)
+        // b = slope / sqrt(var + eps)
+        // value = value * b + a
+
+        std::vector<float> a(channels);
+        std::vector<float> b(channels);
+        for (int i = 0; i < channels; i++)
+        {
+            double sqrt_var = sqrt(bn_running_var[i] + bn_eps);
+
+            if (has_bn_affine)
+            {
+                a[i] = (float)(bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var);
+                b[i] = (float)(bn_weight[i] / sqrt_var);
+            }
+            else
+            {
+                a[i] = (float)(-bn_running_mean[i] / sqrt_var);
+                b[i] = (float)(1.f / sqrt_var);
+            }
+        }
+
+        op->attrs["weight"] = captured_attrs.at("op_0.weight");
+
+        if (has_conv_bias)
+        {
+            op->attrs["bias"] = captured_attrs.at("op_0.bias");
+        }
+        else
+        {
+            // init bias as zero
+            op->attrs["bias"] = Attribute();
+            op->attrs["bias"].type = op->attrs["weight"].type;
+            op->attrs["bias"].shape = {channels};
+            op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
+        }
+
+        auto conv_weight = op->attrs["weight"].get_float32_data();
+        auto conv_bias = op->attrs["bias"].get_float32_data();
+
+        const int outch = captured_params.at("out_channels").i;
+        const int weight_per_outch = op->attrs["weight"].elemcount() / outch;
+
+        for (int i = 0; i < channels; i++)
+        {
+            float* conv_weight_outch = conv_weight.data() + weight_per_outch * i;
+            for (int j = 0; j < weight_per_outch; j++)
+            {
+                conv_weight_outch[j] *= b[i];
+            }
+
+            conv_bias[i] = conv_bias[i] * b[i] + a[i];
+        }
+
+        op->attrs["weight"].set_float32_data(conv_weight);
+        op->attrs["bias"].set_float32_data(conv_bias);
+    }
+};
+
+void fuse_conv3d_batchnorm3d(Graph& graph)
+{
+    fuse_conv3d_batchnorm3d_pass a;
+    int opindex = 0;
+
+    pnnx_graph_rewrite(graph, &a, opindex);
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h
new file mode 100644
index 00000000000..017201d4d8b
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h
@@ -0,0 +1,21 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "ir.h"
+
+namespace pnnx {
+
+void fuse_conv3d_batchnorm3d(Graph& graph);
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp
new file mode 100644
index 00000000000..d01eebeed48
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp
@@ -0,0 +1,156 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "fuse_convtranspose3d_batchnorm3d.h"
+
+#include "pass_level2.h"
+
+#include <math.h>
+
+#include <string.h>
+
+namespace pnnx {
+
+class fuse_convtranspose3d_batchnorm3d_pass : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+nn.ConvTranspose3d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride output_padding=%output_padding padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+nn.BatchNorm3d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "nn.ConvTranspose3d";
+    }
+
+    const char* name_str() const
+    {
+        return "convtransposebn3d";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        op->params["in_channels"] = captured_params.at("in_channels");
+        op->params["out_channels"] = captured_params.at("out_channels");
+        op->params["kernel_size"] = captured_params.at("kernel_size");
+        op->params["stride"] = captured_params.at("stride");
+        op->params["output_padding"] = captured_params.at("output_padding");
+        op->params["padding"] = captured_params.at("padding");
+        op->params["dilation"] = captured_params.at("dilation");
+        op->params["groups"] = captured_params.at("groups");
+        op->params["bias"] = true;
+
+        // resolve merged convtranspose3d weight and bias
+        int channels = captured_params.at("num_features").i;
+        float bn_eps = captured_params.at("eps").f;
+        bool has_bn_affine = captured_params.at("affine").b;
+        bool has_convtranspose_bias = captured_params.at("bias").b;
+
+        auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
+        auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
+        auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
+        auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
+
+        // a = bias - slope * mean / sqrt(var + eps)
+        // b = slope / sqrt(var + eps)
+        // value = value * b + a
+
+        std::vector<float> a(channels);
+        std::vector<float> b(channels);
+        for (int i = 0; i < channels; i++)
+        {
+            double sqrt_var = sqrt(bn_running_var[i] + bn_eps);
+
+            if (has_bn_affine)
+            {
+                a[i] = (float)(bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var);
+                b[i] = (float)(bn_weight[i] / sqrt_var);
+            }
+            else
+            {
+                a[i] = (float)(-bn_running_mean[i] / sqrt_var);
+                b[i] = (float)(1.f / sqrt_var);
+            }
+        }
+
+        op->attrs["weight"] = captured_attrs.at("op_0.weight");
+
+        if (has_convtranspose_bias)
+        {
+            op->attrs["bias"] = captured_attrs.at("op_0.bias");
+        }
+        else
+        {
+            // init bias as zero
+            op->attrs["bias"] = Attribute();
+            op->attrs["bias"].type = op->attrs["weight"].type;
+            op->attrs["bias"].shape = {channels};
+            op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
+        }
+
+        auto conv_weight = op->attrs["weight"].get_float32_data();
+        auto conv_bias = op->attrs["bias"].get_float32_data();
+
+        // group-inch/group-outch/group-kh-kw
+        const int inch = captured_params.at("in_channels").i;
+        const int outch = captured_params.at("out_channels").i;
+        const int groups = captured_params.at("groups").i;
+        const int kd = captured_params.at("kernel_size").ai[0];
+        const int kh = captured_params.at("kernel_size").ai[1];
+        const int kw = captured_params.at("kernel_size").ai[2];
+
+        const int outch_g = outch / groups;
+        const int inch_g = inch / groups;
+        const int maxk = kd * kh * kw;
+
+        for (int g = 0; g < groups; g++)
+        {
+            float* wg = (float*)conv_weight.data() + g * inch_g * outch_g * maxk;
+            for (int i = 0; i < inch_g; i++)
+            {
+                for (int j = 0; j < outch_g; j++)
+                {
+                    for (int k = 0; k < maxk; k++)
+                    {
+                        wg[(i * outch_g + j) * maxk + k] *= b[g * outch_g + j];
+                    }
+                }
+            }
+        }
+
+        for (int i = 0; i < channels; i++)
+        {
+            conv_bias[i] = conv_bias[i] * b[i] + a[i];
+        }
+
+        op->attrs["weight"].set_float32_data(conv_weight);
+        op->attrs["bias"].set_float32_data(conv_bias);
+    }
+};
+
+void fuse_convtranspose3d_batchnorm3d(Graph& graph)
+{
+    fuse_convtranspose3d_batchnorm3d_pass a;
+    int opindex = 0;
+
+    pnnx_graph_rewrite(graph, &a, opindex);
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h
new file mode 100644
index 00000000000..f15e2f41c66
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h
@@ -0,0 +1,21 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
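Aside (not part of the patch): both fusion passes above fold the batchnorm affine transform into the preceding convolution exactly as their a/b comments describe: with b = gamma / sqrt(var + eps) and a = beta - gamma * mean / sqrt(var + eps), each output channel's weights are scaled by b and the bias becomes bias * b + a. A minimal PyTorch sketch of that identity (layer sizes are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv3d(4, 8, 3, bias=True).eval()
bn = nn.BatchNorm3d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # make the folded stats non-trivial
bn.running_var.uniform_(0.5, 2.0)

b = bn.weight / torch.sqrt(bn.running_var + bn.eps)
a = bn.bias - bn.weight * bn.running_mean / torch.sqrt(bn.running_var + bn.eps)

fused = nn.Conv3d(4, 8, 3, bias=True).eval()
with torch.no_grad():
    fused.weight.copy_(conv.weight * b.reshape(-1, 1, 1, 1, 1))  # scale per output channel
    fused.bias.copy_(conv.bias * b + a)

x = torch.rand(1, 4, 10, 12, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # expect True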
+
+#include "ir.h"
+
+namespace pnnx {
+
+void fuse_convtranspose3d_batchnorm3d(Graph& graph);
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp b/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp
index 41dfc65ee39..b71e64dcbde 100644
--- a/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp
+++ b/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp
@@ -1,6 +1,6 @@
 // Tencent is pleased to support the open source community by making ncnn available.
 //
-// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 // in compliance with the License. You may obtain a copy of the License at
@@ -15,7 +15,6 @@
 #include "pass_ncnn.h"
 
 namespace pnnx {
-
 namespace ncnn {
 
 class F_grid_sample : public GraphRewriterPass
@@ -61,11 +60,113 @@ pnnx.Output output 1 0 out
             op->params["1"] = 3;
 
         op->params["2"] = captured_params.at("align_corners").b ? 1 : 0;
+        op->params["3"] = 0;
     }
 };
 
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_grid_sample, 20)
 
+class F_grid_sample_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input input_a 0 1 a
+pnnx.Input input_b 0 1 b
+torch.permute op_0 1 1 b b1 dims=%dims
+F.grid_sample op_1 2 1 a b1 out mode=%mode padding_mode=%padding_mode align_corners=%align_corners
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "GridSample";
+    }
+
+    const char* name_str() const
+    {
+        return "permutegridsample";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::vector<int>& dims = captured_params.at("dims").ai;
+
+        if ((dims == std::vector<int>{1, 2, 0}) || (dims == std::vector<int>{1, 2, 3, 0}))
+            return true;
+        if ((dims == std::vector<int>{0, 2, 3, 1}) || (dims == std::vector<int>{0, 2, 3, 4, 1}))
+            return true;
+        return false;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::string& mode = captured_params.at("mode").s;
+        if (mode == "bilinear")
+            op->params["0"] = 1;
+        if (mode == "nearest")
+            op->params["0"] = 2;
+        if (mode == "bicubic")
+            op->params["0"] = 3;
+
+        const std::string& padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "zeros")
+            op->params["1"] = 1;
+        if (padding_mode == "border")
+            op->params["1"] = 2;
+        if (padding_mode == "reflection")
+            op->params["1"] = 3;
+
+        op->params["2"] = captured_params.at("align_corners").b ? 1 : 0;
+
+        const int batch_index = op->inputs[1]->params["__batch_index"].i;
+
+        const std::vector<int>& dims = captured_params.at("dims").ai;
+
+        int input_rank = (int)op->inputs[0]->shape.size();
+
+        if (input_rank == 0)
+        {
+            // assume input is fine
+            input_rank = (int)dims.size();
+        }
+
+        if (batch_index >= 0 && batch_index < input_rank)
+            input_rank -= 1;
+
+        if (input_rank > 4)
+        {
+            fprintf(stderr, "permute %d-rank tensor is not supported yet!\n", input_rank);
+            return;
+        }
+
+        // drop permute batch index
+        std::vector<int> new_dims;
+        for (int i = 0; i < (int)dims.size(); i++)
+        {
+            if (dims[i] == batch_index)
+                continue;
+
+            int new_dim = dims[i] > batch_index ?
dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + if (input_rank != (int)new_dims.size()) + { + fprintf(stderr, "permute %d-rank tensor with %d-rank dims is not possible\n", input_rank, (int)new_dims.size()); + return; + } + + if ((input_rank == 3 && new_dims == std::vector{1, 2, 0}) || (input_rank == 4 && new_dims == std::vector{1, 2, 3, 0})) + op->params["3"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_grid_sample_1, 19) + } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 5ed5632f7fb..346ee0a955f 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -54,6 +54,7 @@ pnnx_add_test(F_max_pool2d) pnnx_add_test(F_max_pool3d) pnnx_add_test(F_normalize) pnnx_add_test(F_pad) +pnnx_add_test(F_pairwise_distance) pnnx_add_test(F_pixel_shuffle) pnnx_add_test(F_pixel_unshuffle) pnnx_add_test(F_prelu) @@ -236,6 +237,8 @@ pnnx_add_test(torch_topk) pnnx_add_test(torch_transpose) pnnx_add_test(torch_unbind) pnnx_add_test(torch_unsqueeze) +pnnx_add_test(torch_view_as_complex) +pnnx_add_test(torch_view_as_real) pnnx_add_test(torch_zeros) pnnx_add_test(torch_zeros_like) @@ -313,8 +316,10 @@ pnnx_add_test(pnnx_expression) pnnx_add_test(pnnx_fold_constant) pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d) pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) +pnnx_add_test(pnnx_fuse_conv3d_batchnorm3d) pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) +pnnx_add_test(pnnx_fuse_convtranspose3d_batchnorm3d) pnnx_add_test(pnnx_fuse_input_unpack) pnnx_add_test(pnnx_fuse_layernorm) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index d4d026570d6..34bd0cd9d6f 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -52,6 +52,7 @@ pnnx_ncnn_add_test(F_pixel_unshuffle) pnnx_ncnn_add_test(F_prelu) pnnx_ncnn_add_test(F_relu) pnnx_ncnn_add_test(F_relu6) +pnnx_ncnn_add_test(F_selu) pnnx_ncnn_add_test(F_sigmoid) pnnx_ncnn_add_test(F_silu) pnnx_ncnn_add_test(F_softmax) diff --git a/tools/pnnx/tests/ncnn/test_F_celu.py b/tools/pnnx/tests/ncnn/test_F_celu.py index 04ecc37bafa..da7f879dc39 100644 --- a/tools/pnnx/tests/ncnn/test_F_celu.py +++ b/tools/pnnx/tests/ncnn/test_F_celu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.celu(x) y = F.celu(y, 0.8) z = F.celu(z, 0.5) diff --git a/tools/pnnx/tests/ncnn/test_F_elu.py b/tools/pnnx/tests/ncnn/test_F_elu.py index ea32eff96e7..a5693aa6bd5 100644 --- a/tools/pnnx/tests/ncnn/test_F_elu.py +++ b/tools/pnnx/tests/ncnn/test_F_elu.py @@ -20,30 +20,36 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.elu(x) y = F.elu(y, 1.2) z = F.elu(z, -0.6) - return x, y, z + w = F.elu(w, 0.1) + return x, y, z, w def test(): net = Model() net.eval() torch.manual_seed(0) - x = torch.rand(1, 16) - y = torch.rand(1, 2, 16) - z = torch.rand(1, 3, 12, 16) + x = torch.rand(16) + y = torch.rand(2, 16) + z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_elu.pt") # 
torchscript to pnnx import os - os.system("../../src/pnnx test_F_elu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + os.system("../../src/pnnx test_F_elu.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_elu_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_gelu.py b/tools/pnnx/tests/ncnn/test_F_gelu.py index 2f5fe8c7503..0a99e1fc8ee 100644 --- a/tools/pnnx/tests/ncnn/test_F_gelu.py +++ b/tools/pnnx/tests/ncnn/test_F_gelu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.gelu(x) y = F.gelu(y) z = F.gelu(z) diff --git a/tools/pnnx/tests/ncnn/test_F_grid_sample.py b/tools/pnnx/tests/ncnn/test_F_grid_sample.py index c84d38232b1..95ca812eb51 100644 --- a/tools/pnnx/tests/ncnn/test_F_grid_sample.py +++ b/tools/pnnx/tests/ncnn/test_F_grid_sample.py @@ -22,46 +22,56 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, xg1, xg2, y, yg1, yg2): + def forward(self, x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2): # norm to -1 ~ 1 xg1 = xg1 * 2 - 1 xg2 = xg2 * 2 - 1 yg1 = yg1 * 2 - 1 yg2 = yg2 * 2 - 1 - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='reflection', align_corners=True) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='reflection', align_corners=True) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='reflection', align_corners=True) - - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=False) - y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=False) - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='reflection', align_corners=False) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=False) - y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=False) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=False) - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=True) - y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=True) - y = F.grid_sample(y, yg1, 
mode='bilinear', padding_mode='reflection', align_corners=True) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=True) - y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=True) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=True) - - return x, y + x0 = F.grid_sample(x, xg1, mode='bilinear', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bilinear', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='bilinear', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='reflection', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='reflection', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='reflection', align_corners=True) + + y0 = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='bilinear', padding_mode='border', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='reflection', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='zeros', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='nearest', padding_mode='border', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='reflection', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='zeros', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='bilinear', padding_mode='border', align_corners=True) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='reflection', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='zeros', align_corners=True) + y0 = F.grid_sample(y0, yg1, mode='nearest', padding_mode='border', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='reflection', align_corners=True) + + xgp1 = xgp1.permute(0, 2, 3, 1) + xgp2 = xgp2.permute(0, 2, 3, 1) + ygp1 = ygp1.permute(0, 2, 3, 4, 1) + ygp2 = ygp2.permute(0, 2, 3, 4, 1) + + x1 = F.grid_sample(x, xgp1, mode='bilinear', padding_mode='zeros', align_corners=False) + x1 = F.grid_sample(x1, xgp2, mode='bilinear', padding_mode='border', align_corners=False) + + y1 = F.grid_sample(y, ygp1, mode='bilinear', padding_mode='zeros', align_corners=False) + y1 = F.grid_sample(y1, ygp2, mode='bilinear', 
padding_mode='border', align_corners=False) + return x0, y0, x1, y1 def test(): net = Model() @@ -71,25 +81,29 @@ def test(): x = torch.rand(1, 3, 12, 16) xg1 = torch.rand(1, 21, 27, 2) xg2 = torch.rand(1, 12, 16, 2) + xgp1 = torch.rand(1, 2, 21, 27) + xgp2 = torch.rand(1, 2, 12, 16) y = torch.rand(1, 5, 10, 12, 16) yg1 = torch.rand(1, 10, 21, 27, 3) yg2 = torch.rand(1, 10, 12, 16, 3) + ygp1 = torch.rand(1, 3, 10, 21, 27) + ygp2 = torch.rand(1, 3, 10, 12, 16) - a0, a1 = net(x, xg1, xg2, y, yg1, yg2) + a0, a1, a2, a3 = net(x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2) # export torchscript - mod = torch.jit.trace(net, (x, xg1, xg2, y, yg1, yg2)) + mod = torch.jit.trace(net, (x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2)) mod.save("test_F_grid_sample.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_grid_sample.pt inputshape=[1,3,12,16],[1,21,27,2],[1,12,16,2],[1,5,10,12,16],[1,10,21,27,3],[1,10,12,16,3]") + os.system("../../src/pnnx test_F_grid_sample.pt inputshape=[1,3,12,16],[1,21,27,2],[1,12,16,2],[1,2,21,27],[1,2,12,16],[1,5,10,12,16],[1,10,21,27,3],[1,10,12,16,3],[1,3,10,21,27],[1,3,10,12,16]") # ncnn inference import test_F_grid_sample_ncnn - b0, b1 = test_F_grid_sample_ncnn.test_inference() + b0, b1, b2, b3 = test_F_grid_sample_ncnn.test_inference() - return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + return torch.allclose(a0, b0, 1e-6, 1e-6) and torch.allclose(a1, b1, 1e-6, 1e-6) and torch.allclose(a2, b2, 1e-6, 1e-6) and torch.allclose(a3, b3, 1e-6, 1e-6) if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py b/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py index 0d993a16918..0636575f2d4 100644 --- a/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py +++ b/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py @@ -23,11 +23,16 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardsigmoid(x) y = F.hardsigmoid(y) z = hardsigmoid_forward_0(z) - return x, y, z + w = F.hardsigmoid(w) + return x, y, z, w def test(): net = Model() @@ -37,16 +42,17 @@ def test(): x = torch.rand(16) y = torch.rand(2, 16) z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_hardsigmoid.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_hardsigmoid.pt inputshape=[16],[2,16],[3,12,16]") + os.system("../../src/pnnx test_F_hardsigmoid.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_hardsigmoid_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_hardswish.py b/tools/pnnx/tests/ncnn/test_F_hardswish.py index bacf9986974..30fccc477a2 100644 --- a/tools/pnnx/tests/ncnn/test_F_hardswish.py +++ b/tools/pnnx/tests/ncnn/test_F_hardswish.py @@ -30,12 +30,16 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardswish(x) y = hardswish_forward_0(y) z = hardswish_forward_1(z) - z = hardswish_forward_2(z) - return x, y, z + w = hardswish_forward_2(w) + return x, y, z, w def test(): net = Model() @@ -45,16 +49,17 @@ def test(): x = torch.rand(16) y = torch.rand(2, 16) z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - 
a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_hardswish.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_hardswish.pt inputshape=[16],[2,16],[3,12,16]") + os.system("../../src/pnnx test_F_hardswish.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_hardswish_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_hardtanh.py b/tools/pnnx/tests/ncnn/test_F_hardtanh.py index 787bf0b3d32..95c0c2aa16f 100644 --- a/tools/pnnx/tests/ncnn/test_F_hardtanh.py +++ b/tools/pnnx/tests/ncnn/test_F_hardtanh.py @@ -20,11 +20,16 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardtanh(x) y = F.hardtanh(y, -1, 1) z = F.hardtanh(z, -0.1, 0.1) - return x, y, z + w = F.hardtanh(w, -0.2, 0.3) + return x, y, z, w def test(): net = Model() @@ -34,16 +39,17 @@ def test(): x = torch.rand(16) y = torch.rand(2, 16) z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_hardtanh.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_hardtanh.pt inputshape=[16],[2,16],[3,12,16]") + os.system("../../src/pnnx test_F_hardtanh.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_hardtanh_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_leaky_relu.py b/tools/pnnx/tests/ncnn/test_F_leaky_relu.py index 19788c83c91..4606ddce203 100644 --- a/tools/pnnx/tests/ncnn/test_F_leaky_relu.py +++ b/tools/pnnx/tests/ncnn/test_F_leaky_relu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.leaky_relu(x) y = F.leaky_relu(y, 0.1) z = F.leaky_relu(z, -0.22) diff --git a/tools/pnnx/tests/ncnn/test_F_log_softmax.py b/tools/pnnx/tests/ncnn/test_F_log_softmax.py index 3f53229b207..4b0d295f664 100644 --- a/tools/pnnx/tests/ncnn/test_F_log_softmax.py +++ b/tools/pnnx/tests/ncnn/test_F_log_softmax.py @@ -20,12 +20,16 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.log_softmax(x, 0) y = F.log_softmax(y, 1) z = F.log_softmax(z, 2) - z2 = F.log_softmax(z, -1) - return x, y, z, z2 + # w = F.log_softmax(w, -1) TODO + return x, y, z, w def test(): net = Model() @@ -35,16 +39,17 @@ def test(): x = torch.rand(16) y = torch.rand(2, 16) z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_log_softmax.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_log_softmax.pt inputshape=[16],[2,16],[3,12,16]") + os.system("../../src/pnnx test_F_log_softmax.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_log_softmax_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_logsigmoid.py b/tools/pnnx/tests/ncnn/test_F_logsigmoid.py index 023652f7a8a..7d2b304a7b2 100644 --- a/tools/pnnx/tests/ncnn/test_F_logsigmoid.py +++ b/tools/pnnx/tests/ncnn/test_F_logsigmoid.py @@ -21,6 
+21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.logsigmoid(x) y = F.logsigmoid(y) z = F.logsigmoid(z) diff --git a/tools/pnnx/tests/ncnn/test_F_mish.py b/tools/pnnx/tests/ncnn/test_F_mish.py index bf51cf81d52..716f5a009cf 100644 --- a/tools/pnnx/tests/ncnn/test_F_mish.py +++ b/tools/pnnx/tests/ncnn/test_F_mish.py @@ -27,6 +27,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.mish(x) y = mish_forward_0(y) z = mish_forward_1(z) diff --git a/tools/pnnx/tests/ncnn/test_F_prelu.py b/tools/pnnx/tests/ncnn/test_F_prelu.py index 00471747964..fc06db52be9 100644 --- a/tools/pnnx/tests/ncnn/test_F_prelu.py +++ b/tools/pnnx/tests/ncnn/test_F_prelu.py @@ -23,8 +23,12 @@ def __init__(self): self.w4 = nn.Parameter(torch.rand(16)) self.w5 = nn.Parameter(torch.rand(2)) self.w6 = nn.Parameter(torch.rand(3)) + self.w7 = nn.Parameter(torch.rand(12)) def forward(self, x, y, z): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 x = F.prelu(x, self.w4) y = F.prelu(y, self.w5) z = F.prelu(z, self.w6) @@ -38,6 +42,7 @@ def test(): x = torch.rand(1, 16) y = torch.rand(1, 2, 16) z = torch.rand(1, 3, 12, 16) + # w = torch.rand(1, 5, 7, 9, 11) a = net(x, y, z) diff --git a/tools/pnnx/tests/ncnn/test_F_relu.py b/tools/pnnx/tests/ncnn/test_F_relu.py index 6b9c94282cb..b4fad8237a0 100644 --- a/tools/pnnx/tests/ncnn/test_F_relu.py +++ b/tools/pnnx/tests/ncnn/test_F_relu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.relu(x) y = F.relu(y) z = F.relu(z) diff --git a/tools/pnnx/tests/ncnn/test_F_relu6.py b/tools/pnnx/tests/ncnn/test_F_relu6.py index 527029cea20..e93de557bf6 100644 --- a/tools/pnnx/tests/ncnn/test_F_relu6.py +++ b/tools/pnnx/tests/ncnn/test_F_relu6.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.relu6(x) y = F.relu6(y) z = F.relu6(z) diff --git a/tools/pnnx/tests/ncnn/test_F_selu.py b/tools/pnnx/tests/ncnn/test_F_selu.py new file mode 100644 index 00000000000..210b572cdbe --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_selu.py @@ -0,0 +1,67 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
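Aside (not part of the patch): the x * 2 - 1 lines added across these ncnn activation tests remap torch.rand inputs from [0, 1) to [-1, 1); several of these activations are a plain (scaled) identity on non-negative inputs, so without the remap a broken negative branch would pass unnoticed. For example:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(16)     # uniform on [0, 1): every element is non-negative
y = x * 2 - 1          # spans [-1, 1)

print(F.selu(x).min() >= 0)  # positive side of SELU is just a scale, never negative
print(F.selu(y).min() < 0)   # remapped inputs exercise the exponential branch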
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.selu(x) + y = F.selu(y) + z = F.selu(z) + w = F.selu(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(16) + y = torch.rand(2, 16) + z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_selu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_selu.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") + + # ncnn inference + import test_F_selu_ncnn + b = test_F_selu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_sigmoid.py b/tools/pnnx/tests/ncnn/test_F_sigmoid.py index 757a9f6a3f7..f5e0f39d6e1 100644 --- a/tools/pnnx/tests/ncnn/test_F_sigmoid.py +++ b/tools/pnnx/tests/ncnn/test_F_sigmoid.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.sigmoid(x) y = F.sigmoid(y) z = F.sigmoid(z) diff --git a/tools/pnnx/tests/ncnn/test_F_silu.py b/tools/pnnx/tests/ncnn/test_F_silu.py index c3e5970f2f7..78b75ac57cc 100644 --- a/tools/pnnx/tests/ncnn/test_F_silu.py +++ b/tools/pnnx/tests/ncnn/test_F_silu.py @@ -24,6 +24,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.silu(x) y = F.silu(y) z = silu_forward_0(z) diff --git a/tools/pnnx/tests/ncnn/test_F_softmax.py b/tools/pnnx/tests/ncnn/test_F_softmax.py index 83a5324f49d..12a8cada60a 100644 --- a/tools/pnnx/tests/ncnn/test_F_softmax.py +++ b/tools/pnnx/tests/ncnn/test_F_softmax.py @@ -20,12 +20,16 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softmax(x, 0) y = F.softmax(y, 1) z = F.softmax(z, 2) - z2 = F.softmax(z, -1) - return x, y, z, z2 + # w = F.softmax(w, -1) TODO + return x, y, z, w def test(): net = Model() @@ -35,16 +39,17 @@ def test(): x = torch.rand(16) y = torch.rand(2, 16) z = torch.rand(3, 12, 16) + w = torch.rand(5, 7, 9, 11) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_F_softmax.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_softmax.pt inputshape=[16],[2,16],[3,12,16]") + os.system("../../src/pnnx test_F_softmax.pt inputshape=[16],[2,16],[3,12,16],[5,7,9,11]") # ncnn inference import test_F_softmax_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_tanh.py b/tools/pnnx/tests/ncnn/test_F_tanh.py index 9c0b4e2b03b..3ae1b598ed1 100644 --- a/tools/pnnx/tests/ncnn/test_F_tanh.py +++ b/tools/pnnx/tests/ncnn/test_F_tanh.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.tanh(x) y = F.tanh(y) z = F.tanh(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_CELU.py 
b/tools/pnnx/tests/ncnn/test_nn_CELU.py index 097cc22f7a8..36931cc3b37 100644 --- a/tools/pnnx/tests/ncnn/test_nn_CELU.py +++ b/tools/pnnx/tests/ncnn/test_nn_CELU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.CELU(alpha=2.0) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_ELU.py b/tools/pnnx/tests/ncnn/test_nn_ELU.py index cf1975f3cc3..b2d7617a70e 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ELU.py +++ b/tools/pnnx/tests/ncnn/test_nn_ELU.py @@ -23,30 +23,36 @@ def __init__(self): self.act_0 = nn.ELU() self.act_1 = nn.ELU(alpha=1.3) - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) - return x, y, z + w = self.act_1(w) + return x, y, z, w def test(): net = Model() net.eval() torch.manual_seed(0) - x = torch.rand(1, 12) - y = torch.rand(1, 12, 64) - z = torch.rand(1, 12, 24, 64) + x = torch.rand(12) + y = torch.rand(12, 64) + z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_ELU.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_ELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + os.system("../../src/pnnx test_nn_ELU.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_ELU_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_GELU.py b/tools/pnnx/tests/ncnn/test_nn_GELU.py index a3cf990c94c..9c3a139e31f 100644 --- a/tools/pnnx/tests/ncnn/test_nn_GELU.py +++ b/tools/pnnx/tests/ncnn/test_nn_GELU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.GELU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py b/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py index a2013583406..de76482506f 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py +++ b/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py @@ -22,11 +22,16 @@ def __init__(self): self.act_0 = nn.Hardsigmoid() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) - return x, y, z + w = self.act_0(w) + return x, y, z, w def test(): net = Model() @@ -36,16 +41,17 @@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_Hardsigmoid.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_Hardsigmoid.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx test_nn_Hardsigmoid.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_Hardsigmoid_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardswish.py b/tools/pnnx/tests/ncnn/test_nn_Hardswish.py index f56005015ef..91656a2e00a 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Hardswish.py +++ b/tools/pnnx/tests/ncnn/test_nn_Hardswish.py @@ -22,11 +22,16 @@ def __init__(self): self.act_0 = nn.Hardswish() - def forward(self, x, y, z): + def forward(self, x, 
y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) - return x, y, z + w = self.act_0(w) + return x, y, z, w def test(): net = Model() @@ -36,16 +41,17 @@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_Hardswish.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_Hardswish.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx test_nn_Hardswish.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_Hardswish_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py b/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py index d8894aa2faa..ea342d89c47 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py +++ b/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py @@ -23,11 +23,16 @@ def __init__(self): self.act_0 = nn.Hardtanh() self.act_1 = nn.Hardtanh(-0.2, 0.2) - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) - return x, y, z + w = self.act_1(w) + return x, y, z, w def test(): net = Model() @@ -37,16 +42,17 @@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_Hardtanh.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_Hardtanh.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx test_nn_Hardtanh.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_Hardtanh_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py b/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py index 563e8306a33..a0a8b6e549a 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py +++ b/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.LeakyReLU(negative_slope=-0.24) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py b/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py index 26f676c9495..d74477dc9cc 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py +++ b/tools/pnnx/tests/ncnn/test_nn_LogSigmoid.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.LogSigmoid() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py b/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py index a02c050e674..27caaa17705 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py +++ b/tools/pnnx/tests/ncnn/test_nn_LogSoftmax.py @@ -25,12 +25,16 @@ def __init__(self): self.act_2 = nn.LogSoftmax(dim=2) self.act_3 = nn.LogSoftmax(dim=-1) - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_1(y) z = self.act_2(z) - z2 = self.act_3(z) - return x, y, z, z2 + # w = self.act_3(w) TODO + return x, y, z, w def test(): net = Model() @@ -40,16 +44,17 
@@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_LogSoftmax.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_LogSoftmax.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx test_nn_LogSoftmax.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_LogSoftmax_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_Mish.py b/tools/pnnx/tests/ncnn/test_nn_Mish.py index e5ec2f09b1a..0767eb1b1ca 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Mish.py +++ b/tools/pnnx/tests/ncnn/test_nn_Mish.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Mish() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_PReLU.py b/tools/pnnx/tests/ncnn/test_nn_PReLU.py index 53c9534d5c8..67c7717f734 100644 --- a/tools/pnnx/tests/ncnn/test_nn_PReLU.py +++ b/tools/pnnx/tests/ncnn/test_nn_PReLU.py @@ -24,6 +24,10 @@ def __init__(self): self.prelu_1 = nn.PReLU(num_parameters=1, init=0.12) def forward(self, x, y, z): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + x = self.prelu_0(x) x = self.prelu_1(x) @@ -32,6 +36,7 @@ def forward(self, x, y, z): z = self.prelu_0(z) z = self.prelu_1(z) + return x, y, z def test(): @@ -42,6 +47,7 @@ def test(): x = torch.rand(1, 12) y = torch.rand(1, 12, 64) z = torch.rand(1, 12, 24, 64) + # w = torch.rand(1, 12, 24, 32, 64) a = net(x, y, z) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReLU.py b/tools/pnnx/tests/ncnn/test_nn_ReLU.py index c448bb92bce..b3adf2c7573 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ReLU.py +++ b/tools/pnnx/tests/ncnn/test_nn_ReLU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.ReLU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReLU6.py b/tools/pnnx/tests/ncnn/test_nn_ReLU6.py index 947794f78a2..9dbea33a36d 100644 --- a/tools/pnnx/tests/ncnn/test_nn_ReLU6.py +++ b/tools/pnnx/tests/ncnn/test_nn_ReLU6.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.ReLU6() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_SELU.py b/tools/pnnx/tests/ncnn/test_nn_SELU.py index 55da9bfedd0..b8cf59006b9 100644 --- a/tools/pnnx/tests/ncnn/test_nn_SELU.py +++ b/tools/pnnx/tests/ncnn/test_nn_SELU.py @@ -22,11 +22,16 @@ def __init__(self): self.act_0 = nn.SELU() - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) - return x, y, z + w = self.act_0(w) + return x, y, z, w def test(): net = Model() @@ -36,16 +41,17 @@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_SELU.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_SELU.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx 
test_nn_SELU.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_SELU_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_SiLU.py b/tools/pnnx/tests/ncnn/test_nn_SiLU.py index 68609390f74..3a4d35a67cf 100644 --- a/tools/pnnx/tests/ncnn/test_nn_SiLU.py +++ b/tools/pnnx/tests/ncnn/test_nn_SiLU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.SiLU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py b/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py index bb7f15cb8f8..246e63084dd 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py +++ b/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Sigmoid() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/ncnn/test_nn_Softmax.py b/tools/pnnx/tests/ncnn/test_nn_Softmax.py index d4ca3df0ff2..80c05f5e638 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Softmax.py +++ b/tools/pnnx/tests/ncnn/test_nn_Softmax.py @@ -25,12 +25,16 @@ def __init__(self): self.act_2 = nn.Softmax(dim=2) self.act_3 = nn.Softmax(dim=-1) - def forward(self, x, y, z): + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_1(y) z = self.act_2(z) - z2 = self.act_3(z) - return x, y, z, z2 + # w = self.act_3(w) TODO + return x, y, z, w def test(): net = Model() @@ -40,16 +44,17 @@ def test(): x = torch.rand(12) y = torch.rand(12, 64) z = torch.rand(12, 24, 64) + w = torch.rand(12, 24, 32, 64) - a = net(x, y, z) + a = net(x, y, z, w) # export torchscript - mod = torch.jit.trace(net, (x, y, z)) + mod = torch.jit.trace(net, (x, y, z, w)) mod.save("test_nn_Softmax.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_Softmax.pt inputshape=[12],[12,64],[12,24,64]") + os.system("../../src/pnnx test_nn_Softmax.pt inputshape=[12],[12,64],[12,24,64],[12,24,32,64]") # ncnn inference import test_nn_Softmax_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_Softmax2d.py b/tools/pnnx/tests/ncnn/test_nn_Softmax2d.py index c92537e9034..ddeb8586a39 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Softmax2d.py +++ b/tools/pnnx/tests/ncnn/test_nn_Softmax2d.py @@ -23,6 +23,7 @@ def __init__(self): self.act_0 = nn.Softmax2d() def forward(self, x): + x = x * 2 - 1 x = self.act_0(x) return x diff --git a/tools/pnnx/tests/ncnn/test_nn_Tanh.py b/tools/pnnx/tests/ncnn/test_nn_Tanh.py index 34f0a1ac1cf..a60d1345d23 100644 --- a/tools/pnnx/tests/ncnn/test_nn_Tanh.py +++ b/tools/pnnx/tests/ncnn/test_nn_Tanh.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Tanh() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_F_celu.py b/tools/pnnx/tests/test_F_celu.py index 43b25f8547c..e68f17cda53 100644 --- a/tools/pnnx/tests/test_F_celu.py +++ b/tools/pnnx/tests/test_F_celu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.celu(x) y = F.celu(y, 0.8) z = F.celu(z, 0.5) diff --git a/tools/pnnx/tests/test_F_elu.py b/tools/pnnx/tests/test_F_elu.py index 73047093eb8..ac37aff6de7 100644 --- a/tools/pnnx/tests/test_F_elu.py +++ 
b/tools/pnnx/tests/test_F_elu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.elu(x) y = F.elu(y, 1.2) z = F.elu(z, -0.6) diff --git a/tools/pnnx/tests/test_F_gelu.py b/tools/pnnx/tests/test_F_gelu.py index 800251f967f..5a065e2261d 100644 --- a/tools/pnnx/tests/test_F_gelu.py +++ b/tools/pnnx/tests/test_F_gelu.py @@ -28,6 +28,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.gelu(x) y = F.gelu(y) z = gelu_forward_0(z) @@ -59,7 +63,7 @@ def test(): b = test_F_gelu_pnnx.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-4, 1e-4): + if not torch.allclose(a0, b0, 1e-3, 1e-3): return False return True diff --git a/tools/pnnx/tests/test_F_hardshrink.py b/tools/pnnx/tests/test_F_hardshrink.py index 3f9cee9ec9c..0836c3bcb8f 100644 --- a/tools/pnnx/tests/test_F_hardshrink.py +++ b/tools/pnnx/tests/test_F_hardshrink.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardshrink(x) y = F.hardshrink(y, 0.1) z = F.hardshrink(z, 0.22) diff --git a/tools/pnnx/tests/test_F_hardsigmoid.py b/tools/pnnx/tests/test_F_hardsigmoid.py index 07e569cd879..e2e581aed27 100644 --- a/tools/pnnx/tests/test_F_hardsigmoid.py +++ b/tools/pnnx/tests/test_F_hardsigmoid.py @@ -37,6 +37,10 @@ def __init__(self): self.h_sigmoid = h_sigmoid(); def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardsigmoid(x) y = F.hardsigmoid(y) z = self.h_sigmoid(z) diff --git a/tools/pnnx/tests/test_F_hardswish.py b/tools/pnnx/tests/test_F_hardswish.py index f39463b9c36..7bd963529c8 100644 --- a/tools/pnnx/tests/test_F_hardswish.py +++ b/tools/pnnx/tests/test_F_hardswish.py @@ -34,6 +34,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardswish(x) y = hardswish_forward_0(y) z = hardswish_forward_1(z) diff --git a/tools/pnnx/tests/test_F_hardtanh.py b/tools/pnnx/tests/test_F_hardtanh.py index 54bcba6e122..4847d7b096f 100644 --- a/tools/pnnx/tests/test_F_hardtanh.py +++ b/tools/pnnx/tests/test_F_hardtanh.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.hardtanh(x) y = F.hardtanh(y, -1, 1) z = F.hardtanh(z, -0.1, 0.1) diff --git a/tools/pnnx/tests/test_F_leaky_relu.py b/tools/pnnx/tests/test_F_leaky_relu.py index 700d78dafa3..852237495e5 100644 --- a/tools/pnnx/tests/test_F_leaky_relu.py +++ b/tools/pnnx/tests/test_F_leaky_relu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.leaky_relu(x) y = F.leaky_relu(y, 0.1) z = F.leaky_relu(z, -0.22) diff --git a/tools/pnnx/tests/test_F_log_softmax.py b/tools/pnnx/tests/test_F_log_softmax.py index 5906a1be996..13181405e71 100644 --- a/tools/pnnx/tests/test_F_log_softmax.py +++ b/tools/pnnx/tests/test_F_log_softmax.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.log_softmax(x, 1) y = F.log_softmax(y, 0) z = 
F.log_softmax(z, 2) diff --git a/tools/pnnx/tests/test_F_logsigmoid.py b/tools/pnnx/tests/test_F_logsigmoid.py index 096c2ab5254..68daf4819e3 100644 --- a/tools/pnnx/tests/test_F_logsigmoid.py +++ b/tools/pnnx/tests/test_F_logsigmoid.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.logsigmoid(x) y = F.logsigmoid(y) z = F.logsigmoid(z) diff --git a/tools/pnnx/tests/test_F_mish.py b/tools/pnnx/tests/test_F_mish.py index a4bf52c5631..73e40a31df8 100644 --- a/tools/pnnx/tests/test_F_mish.py +++ b/tools/pnnx/tests/test_F_mish.py @@ -27,6 +27,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.mish(x) y = F.mish(y) z = mish_forward_0(z) diff --git a/tools/pnnx/tests/test_F_pairwise_distance.py b/tools/pnnx/tests/test_F_pairwise_distance.py new file mode 100644 index 00000000000..243f61e1b0e --- /dev/null +++ b/tools/pnnx/tests/test_F_pairwise_distance.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
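A note on the recurring "x = x * 2 - 1" lines added in the hunks above and throughout the rest of this patch: torch.rand samples uniformly from [0, 1), so without the remap these activation tests would only ever see non-negative inputs. Shifting the inputs to [-1, 1) exercises the negative-side branches of SiLU, Sigmoid, Tanh, ELU and the other activations. A minimal illustration, not part of the patch itself:

import torch

torch.manual_seed(0)
x = torch.rand(12, 64)   # uniform in [0, 1), so all values are non-negative
x = x * 2 - 1            # remapped to [-1, 1), covering negative inputs as well
assert x.min().item() < 0 < x.max().item()

The body of the new test_F_pairwise_distance.py continues below.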
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + z1 = F.pairwise_distance(x,y,p=1,keepdim=False) + z2 = F.pairwise_distance(x,y,p=2,keepdim=True) + z3 = F.pairwise_distance(x,y) + z4 = F.pairwise_distance(x,y,eps = 1e-3) + return z1,z2,z3,z4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(12, 128, 128) + y = torch.rand(12, 128, 128) + + a0,a1,a2,a3 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_pairwise_distance.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_pairwise_distance.pt inputshape=[12,128,128],[12,128,128]") + + # pnnx inference + import test_F_pairwise_distance_pnnx + b0,b1,b2,b3 = test_F_pairwise_distance_pnnx.test_inference() + + return torch.equal(a0,b0) and torch.equal(a1,b1) and torch.equal(a2,b2) and torch.equal(a3,b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_prelu.py b/tools/pnnx/tests/test_F_prelu.py index 60a8e7c797b..759073a1f55 100644 --- a/tools/pnnx/tests/test_F_prelu.py +++ b/tools/pnnx/tests/test_F_prelu.py @@ -26,6 +26,10 @@ def __init__(self): self.w7 = nn.Parameter(torch.rand(1)) def forward(self, x, y, z, w, w0, w1, w2, w3): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.prelu(x, w0) x = F.prelu(x, self.w4) y = F.prelu(y, w1) diff --git a/tools/pnnx/tests/test_F_relu.py b/tools/pnnx/tests/test_F_relu.py index 0319948f7f4..c7337047aa9 100644 --- a/tools/pnnx/tests/test_F_relu.py +++ b/tools/pnnx/tests/test_F_relu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.relu(x) y = F.relu(y) z = F.relu(z) diff --git a/tools/pnnx/tests/test_F_relu6.py b/tools/pnnx/tests/test_F_relu6.py index 147d25002b7..f855be23e9a 100644 --- a/tools/pnnx/tests/test_F_relu6.py +++ b/tools/pnnx/tests/test_F_relu6.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.relu6(x) y = F.relu6(y) z = F.relu6(z) diff --git a/tools/pnnx/tests/test_F_rrelu.py b/tools/pnnx/tests/test_F_rrelu.py index 3dee3fe5e66..bec901db2bd 100644 --- a/tools/pnnx/tests/test_F_rrelu.py +++ b/tools/pnnx/tests/test_F_rrelu.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.rrelu(x) y = F.rrelu(y, 0.01) z = F.rrelu(z, 0.125, 0.3333) @@ -37,7 +41,7 @@ def test(): z = torch.rand(1, 3, 12, 16) w = torch.rand(1, 5, 7, 9, 11) - a0, a1, a2, a3 = net(x, y, z, w) + a = net(x, y, z, w) # export torchscript mod = torch.jit.trace(net, (x, y, z, w)) @@ -49,9 +53,12 @@ def test(): # pnnx inference import test_F_rrelu_pnnx - b0, b1, b2, b3 = test_F_rrelu_pnnx.test_inference() + b = test_F_rrelu_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_F_selu.py b/tools/pnnx/tests/test_F_selu.py index 10e3bc4bc57..a4d971904c9 100644 --- a/tools/pnnx/tests/test_F_selu.py +++ b/tools/pnnx/tests/test_F_selu.py @@ 
-21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.selu(x) y = F.selu(y) z = F.selu(z) diff --git a/tools/pnnx/tests/test_F_sigmoid.py b/tools/pnnx/tests/test_F_sigmoid.py index 282f09ec865..e0a3c4b2379 100644 --- a/tools/pnnx/tests/test_F_sigmoid.py +++ b/tools/pnnx/tests/test_F_sigmoid.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.sigmoid(x) y = F.sigmoid(y) z = F.sigmoid(z) diff --git a/tools/pnnx/tests/test_F_silu.py b/tools/pnnx/tests/test_F_silu.py index 21a124a8e92..b56ea131cfe 100644 --- a/tools/pnnx/tests/test_F_silu.py +++ b/tools/pnnx/tests/test_F_silu.py @@ -24,6 +24,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.silu(x) y = F.silu(y) z = F.silu(z) diff --git a/tools/pnnx/tests/test_F_softmax.py b/tools/pnnx/tests/test_F_softmax.py index c3110f316ef..7092c5b680b 100644 --- a/tools/pnnx/tests/test_F_softmax.py +++ b/tools/pnnx/tests/test_F_softmax.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softmax(x, 1) y = F.softmax(y, 0) z = F.softmax(z, 2) diff --git a/tools/pnnx/tests/test_F_softmin.py b/tools/pnnx/tests/test_F_softmin.py index 98e4a9e2a21..34dd0812601 100644 --- a/tools/pnnx/tests/test_F_softmin.py +++ b/tools/pnnx/tests/test_F_softmin.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softmin(x, 1) y = F.softmin(y, 0) z = F.softmin(z, 2) diff --git a/tools/pnnx/tests/test_F_softplus.py b/tools/pnnx/tests/test_F_softplus.py index dd2986b7e71..54ea907a3b8 100644 --- a/tools/pnnx/tests/test_F_softplus.py +++ b/tools/pnnx/tests/test_F_softplus.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softplus(x) y = F.softplus(y, 2, 1.2) z = F.softplus(z, -0.7, 15) diff --git a/tools/pnnx/tests/test_F_softshrink.py b/tools/pnnx/tests/test_F_softshrink.py index 3bf8443c670..170402785f0 100644 --- a/tools/pnnx/tests/test_F_softshrink.py +++ b/tools/pnnx/tests/test_F_softshrink.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softshrink(x) y = F.softshrink(y, 0.1) z = F.softshrink(z, 0.22) diff --git a/tools/pnnx/tests/test_F_softsign.py b/tools/pnnx/tests/test_F_softsign.py index eb4b7c41d52..54b79b4aef6 100644 --- a/tools/pnnx/tests/test_F_softsign.py +++ b/tools/pnnx/tests/test_F_softsign.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.softsign(x) y = F.softsign(y) z = F.softsign(z) diff --git a/tools/pnnx/tests/test_F_tanh.py b/tools/pnnx/tests/test_F_tanh.py index 800d558caa6..8a50279c242 100644 --- a/tools/pnnx/tests/test_F_tanh.py +++ b/tools/pnnx/tests/test_F_tanh.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 
2 - 1 x = F.tanh(x) y = F.tanh(y) z = F.tanh(z) diff --git a/tools/pnnx/tests/test_F_tanhshrink.py b/tools/pnnx/tests/test_F_tanhshrink.py index da2602f9de5..5c0c96e757d 100644 --- a/tools/pnnx/tests/test_F_tanhshrink.py +++ b/tools/pnnx/tests/test_F_tanhshrink.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.tanhshrink(x) y = F.tanhshrink(y) z = F.tanhshrink(z) diff --git a/tools/pnnx/tests/test_F_threshold.py b/tools/pnnx/tests/test_F_threshold.py index 6ad2b8a8913..1b766d3c6e0 100644 --- a/tools/pnnx/tests/test_F_threshold.py +++ b/tools/pnnx/tests/test_F_threshold.py @@ -21,6 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = F.threshold(x, 0.1, 20) y = F.threshold(y, 0.3, 0.4) z = F.threshold(z, 0.1, 20) diff --git a/tools/pnnx/tests/test_nn_CELU.py b/tools/pnnx/tests/test_nn_CELU.py index 5eefc655124..9ecf6307b37 100644 --- a/tools/pnnx/tests/test_nn_CELU.py +++ b/tools/pnnx/tests/test_nn_CELU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.CELU(alpha=2.0) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_ELU.py b/tools/pnnx/tests/test_nn_ELU.py index 3d5d4a0ebd3..a296226af15 100644 --- a/tools/pnnx/tests/test_nn_ELU.py +++ b/tools/pnnx/tests/test_nn_ELU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.ELU(alpha=1.3) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_GELU.py b/tools/pnnx/tests/test_nn_GELU.py index c11f34e6a65..7a2691654dc 100644 --- a/tools/pnnx/tests/test_nn_GELU.py +++ b/tools/pnnx/tests/test_nn_GELU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.GELU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Hardshrink.py b/tools/pnnx/tests/test_nn_Hardshrink.py index c6a42c1e934..b9e0a3f36c5 100644 --- a/tools/pnnx/tests/test_nn_Hardshrink.py +++ b/tools/pnnx/tests/test_nn_Hardshrink.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.Hardshrink(lambd=0.3) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_Hardsigmoid.py b/tools/pnnx/tests/test_nn_Hardsigmoid.py index f6ab4ef327b..540202d3835 100644 --- a/tools/pnnx/tests/test_nn_Hardsigmoid.py +++ b/tools/pnnx/tests/test_nn_Hardsigmoid.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Hardsigmoid() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Hardswish.py b/tools/pnnx/tests/test_nn_Hardswish.py index 73e901f7044..26f7681596c 100644 --- a/tools/pnnx/tests/test_nn_Hardswish.py +++ b/tools/pnnx/tests/test_nn_Hardswish.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Hardswish() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git 
a/tools/pnnx/tests/test_nn_Hardtanh.py b/tools/pnnx/tests/test_nn_Hardtanh.py index 69886857a0c..cc540b807d9 100644 --- a/tools/pnnx/tests/test_nn_Hardtanh.py +++ b/tools/pnnx/tests/test_nn_Hardtanh.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.Hardtanh(-0.2, 0.2) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_LeakyReLU.py b/tools/pnnx/tests/test_nn_LeakyReLU.py index 77621712c21..b34063b9609 100644 --- a/tools/pnnx/tests/test_nn_LeakyReLU.py +++ b/tools/pnnx/tests/test_nn_LeakyReLU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.LeakyReLU(negative_slope=-0.24) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_LogSigmoid.py b/tools/pnnx/tests/test_nn_LogSigmoid.py index d01b30c8612..0e1f68d67e5 100644 --- a/tools/pnnx/tests/test_nn_LogSigmoid.py +++ b/tools/pnnx/tests/test_nn_LogSigmoid.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.LogSigmoid() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_LogSoftmax.py b/tools/pnnx/tests/test_nn_LogSoftmax.py index bff38ed1603..25fd3604a11 100644 --- a/tools/pnnx/tests/test_nn_LogSoftmax.py +++ b/tools/pnnx/tests/test_nn_LogSoftmax.py @@ -26,6 +26,10 @@ def __init__(self): self.act_3 = nn.LogSoftmax(dim=2) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_1(y) z = self.act_2(z) diff --git a/tools/pnnx/tests/test_nn_Mish.py b/tools/pnnx/tests/test_nn_Mish.py index e06ea91d94a..f2b0b36b0a6 100644 --- a/tools/pnnx/tests/test_nn_Mish.py +++ b/tools/pnnx/tests/test_nn_Mish.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Mish() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_PReLU.py b/tools/pnnx/tests/test_nn_PReLU.py index 84752119432..0e748499a34 100644 --- a/tools/pnnx/tests/test_nn_PReLU.py +++ b/tools/pnnx/tests/test_nn_PReLU.py @@ -24,6 +24,11 @@ def __init__(self): self.prelu_1 = nn.PReLU(num_parameters=1, init=0.12) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.prelu_0(x) x = self.prelu_1(x) diff --git a/tools/pnnx/tests/test_nn_RReLU.py b/tools/pnnx/tests/test_nn_RReLU.py index f8929054348..55df6e62859 100644 --- a/tools/pnnx/tests/test_nn_RReLU.py +++ b/tools/pnnx/tests/test_nn_RReLU.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.RReLU(lower=0.1, upper=0.42) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) @@ -40,7 +44,7 @@ def test(): z = torch.rand(1, 12, 24, 64) w = torch.rand(1, 12, 24, 32, 64) - a0, a1, a2, a3 = net(x, y, z, w) + a = net(x, y, z, w) # export torchscript mod = torch.jit.trace(net, (x, y, z, w)) @@ -52,9 +56,12 @@ def test(): # pnnx inference import test_nn_RReLU_pnnx - b0, b1, b2, b3 = test_nn_RReLU_pnnx.test_inference() + b = test_nn_RReLU_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + for a0, b0 in zip(a, b): + if not 
torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_ReLU.py b/tools/pnnx/tests/test_nn_ReLU.py index 5dddab517f4..cfa92e49608 100644 --- a/tools/pnnx/tests/test_nn_ReLU.py +++ b/tools/pnnx/tests/test_nn_ReLU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.ReLU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_ReLU6.py b/tools/pnnx/tests/test_nn_ReLU6.py index 92e08c1a27e..e6d36d4b0d7 100644 --- a/tools/pnnx/tests/test_nn_ReLU6.py +++ b/tools/pnnx/tests/test_nn_ReLU6.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.ReLU6() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_SELU.py b/tools/pnnx/tests/test_nn_SELU.py index 273b0e588b4..bc085a17c07 100644 --- a/tools/pnnx/tests/test_nn_SELU.py +++ b/tools/pnnx/tests/test_nn_SELU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.SELU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_SiLU.py b/tools/pnnx/tests/test_nn_SiLU.py index c1a98711b67..92177efdc9b 100644 --- a/tools/pnnx/tests/test_nn_SiLU.py +++ b/tools/pnnx/tests/test_nn_SiLU.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.SiLU() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Sigmoid.py b/tools/pnnx/tests/test_nn_Sigmoid.py index 28c922c44b7..7bc2417ccd3 100644 --- a/tools/pnnx/tests/test_nn_Sigmoid.py +++ b/tools/pnnx/tests/test_nn_Sigmoid.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Sigmoid() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Softmax.py b/tools/pnnx/tests/test_nn_Softmax.py index 475385d259c..84db2d884fb 100644 --- a/tools/pnnx/tests/test_nn_Softmax.py +++ b/tools/pnnx/tests/test_nn_Softmax.py @@ -26,6 +26,10 @@ def __init__(self): self.act_3 = nn.Softmax(dim=2) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_1(y) z = self.act_2(z) diff --git a/tools/pnnx/tests/test_nn_Softmax2d.py b/tools/pnnx/tests/test_nn_Softmax2d.py index e75ce61d252..c370903cdb4 100644 --- a/tools/pnnx/tests/test_nn_Softmax2d.py +++ b/tools/pnnx/tests/test_nn_Softmax2d.py @@ -23,6 +23,7 @@ def __init__(self): self.act_0 = nn.Softmax2d() def forward(self, x): + x = x * 2 - 1 x = self.act_0(x) return x diff --git a/tools/pnnx/tests/test_nn_Softmin.py b/tools/pnnx/tests/test_nn_Softmin.py index 8560aef1dd5..c7378136448 100644 --- a/tools/pnnx/tests/test_nn_Softmin.py +++ b/tools/pnnx/tests/test_nn_Softmin.py @@ -26,6 +26,10 @@ def __init__(self): self.act_3 = nn.Softmin(dim=2) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_1(y) z = self.act_2(z) diff --git a/tools/pnnx/tests/test_nn_Softplus.py b/tools/pnnx/tests/test_nn_Softplus.py index b95f6826bbb..cf7a9310905 100644 --- a/tools/pnnx/tests/test_nn_Softplus.py +++ 
b/tools/pnnx/tests/test_nn_Softplus.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.Softplus(beta=0.7, threshold=15) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_Softshrink.py b/tools/pnnx/tests/test_nn_Softshrink.py index db0ad788e98..6f74888264c 100644 --- a/tools/pnnx/tests/test_nn_Softshrink.py +++ b/tools/pnnx/tests/test_nn_Softshrink.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.Softshrink(lambd=1.3) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_nn_Softsign.py b/tools/pnnx/tests/test_nn_Softsign.py index 088933895ff..4c2b347662d 100644 --- a/tools/pnnx/tests/test_nn_Softsign.py +++ b/tools/pnnx/tests/test_nn_Softsign.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Softsign() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Tanh.py b/tools/pnnx/tests/test_nn_Tanh.py index f9ec1babbeb..7a8fb11f7bb 100644 --- a/tools/pnnx/tests/test_nn_Tanh.py +++ b/tools/pnnx/tests/test_nn_Tanh.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Tanh() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Tanhshrink.py b/tools/pnnx/tests/test_nn_Tanhshrink.py index 4d1611b9adb..7f2bf924195 100644 --- a/tools/pnnx/tests/test_nn_Tanhshrink.py +++ b/tools/pnnx/tests/test_nn_Tanhshrink.py @@ -23,6 +23,10 @@ def __init__(self): self.act_0 = nn.Tanhshrink() def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_0(z) diff --git a/tools/pnnx/tests/test_nn_Threshold.py b/tools/pnnx/tests/test_nn_Threshold.py index 329b3c3440a..6cf5c0a1115 100644 --- a/tools/pnnx/tests/test_nn_Threshold.py +++ b/tools/pnnx/tests/test_nn_Threshold.py @@ -24,6 +24,10 @@ def __init__(self): self.act_1 = nn.Threshold(0.3, 0.4) def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 x = self.act_0(x) y = self.act_0(y) z = self.act_1(z) diff --git a/tools/pnnx/tests/test_pnnx_fuse_conv3d_batchnorm3d.py b/tools/pnnx/tests/test_pnnx_fuse_conv3d_batchnorm3d.py new file mode 100644 index 00000000000..6a6fea3e458 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_conv3d_batchnorm3d.py @@ -0,0 +1,98 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
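The new test_pnnx_fuse_conv3d_batchnorm3d.py, whose header appears above and whose body continues below, checks that pnnx can fold a BatchNorm3d into the preceding Conv3d without changing the output. A rough sketch of the algebra such a fusion relies on, assuming an affine BatchNorm in eval mode (layers with affine=False correspond to gamma = 1, beta = 0); this is an illustration of the idea, not pnnx's implementation:

import torch
import torch.nn as nn

def fold_bn_into_conv3d(conv: nn.Conv3d, bn: nn.BatchNorm3d):
    # per-output-channel scale derived from the frozen BatchNorm statistics
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused_w = conv.weight.data * scale.reshape(-1, 1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused_b = (bias - bn.running_mean) * scale + bn.bias
    return fused_w, fused_b

With fused_w and fused_b substituted into the convolution, the BatchNorm layer can be dropped entirely, which is why the test only needs an allclose comparison with 1e-4 tolerances rather than exact equality.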
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv3d(in_channels=12, out_channels=16, kernel_size=3) + self.bn_0 = nn.BatchNorm3d(num_features=16) + self.conv_1 = nn.Conv3d(in_channels=16, out_channels=20, kernel_size=(2,4,2), stride=(2,1,2), padding=2, dilation=1) + self.bn_1 = nn.BatchNorm3d(num_features=20) + self.conv_2 = nn.Conv3d(in_channels=20, out_channels=24, kernel_size=(1,3,3), stride=1, padding=(2,4,4), dilation=1, groups=1, bias=False) + self.bn_2 = nn.BatchNorm3d(num_features=24) + if version.parse(torch.__version__) < version.parse('1.9'): + self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,5), stride=1, padding=0, dilation=1, groups=4, bias=True) + else: + self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,5), stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.bn_3 = nn.BatchNorm3d(num_features=28) + if version.parse(torch.__version__) < version.parse('1.9'): + self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=(1,2,1), groups=2, bias=False, padding_mode='zeros') + else: + self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,1), groups=2, bias=False, padding_mode='zeros') + self.bn_4 = nn.BatchNorm3d(num_features=32) + if version.parse(torch.__version__) >= version.parse('1.10'): + self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + self.bn_5 = nn.BatchNorm3d(num_features=32) + self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + self.bn_6 = nn.BatchNorm3d(num_features=28) + #self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular') + #self.bn_7 = nn.BatchNorm3d(num_features=24) + + def forward(self, x): + x = self.conv_0(x) + x = self.bn_0(x) + x = self.conv_1(x) + x = self.bn_1(x) + x = self.conv_2(x) + x = self.bn_2(x) + x = self.conv_3(x) + x = self.bn_3(x) + x = self.conv_4(x) + x = self.bn_4(x) + if version.parse(torch.__version__) < version.parse('1.10'): + return x + + x = self.conv_5(x) + x = self.bn_5(x) + x = self.conv_6(x) + x = self.bn_6(x) + #x = self.conv_7(x) + #x = self.bn_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_conv3d_batchnorm3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_conv3d_batchnorm3d.pt inputshape=[1,12,64,64,64]") + + # pnnx inference + import test_pnnx_fuse_conv3d_batchnorm3d_pnnx + b = test_pnnx_fuse_conv3d_batchnorm3d_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_convtranspose3d_batchnorm3d.py b/tools/pnnx/tests/test_pnnx_fuse_convtranspose3d_batchnorm3d.py new file mode 100644 index 00000000000..996cfc2d902 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_convtranspose3d_batchnorm3d.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making 
ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose3d(in_channels=12, out_channels=16, kernel_size=3) + self.bn_0 = nn.BatchNorm3d(num_features=16) + self.deconv_1 = nn.ConvTranspose3d(in_channels=16, out_channels=20, kernel_size=(2,4,2), stride=(2,1,2), padding=2, output_padding=0) + self.bn_1 = nn.BatchNorm3d(num_features=20) + self.deconv_2 = nn.ConvTranspose3d(in_channels=20, out_channels=24, kernel_size=(1,3,3), stride=1, padding=(2,4,4), output_padding=(0,0,0), dilation=1, groups=1, bias=False) + self.bn_2 = nn.BatchNorm3d(num_features=24, eps=1e-1, affine=False) + self.deconv_3 = nn.ConvTranspose3d(in_channels=24, out_channels=28, kernel_size=(5,4,5), stride=2, padding=0, output_padding=(0,1,0), dilation=1, groups=4, bias=True) + self.bn_3 = nn.BatchNorm3d(num_features=28, eps=1e-1, affine=False) + self.deconv_4 = nn.ConvTranspose3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=(1,2,2), groups=2, bias=False) + self.bn_4 = nn.BatchNorm3d(num_features=32) + self.deconv_5 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.bn_5 = nn.BatchNorm3d(num_features=32) + self.deconv_6 = nn.ConvTranspose3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.bn_6 = nn.BatchNorm3d(num_features=28, affine=True) + self.deconv_7 = nn.ConvTranspose3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6,6), output_padding=(1,0,0), dilation=2, groups=1, bias=True) + self.bn_7 = nn.BatchNorm3d(num_features=24, affine=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.bn_0(x) + x = self.deconv_1(x) + x = self.bn_1(x) + x = self.deconv_2(x) + x = self.bn_2(x) + x = self.deconv_3(x) + x = self.bn_3(x) + x = self.deconv_4(x) + x = self.bn_4(x) + x = self.deconv_5(x) + x = self.bn_5(x) + x = self.deconv_6(x) + x = self.bn_6(x) + x = self.deconv_7(x) + x = self.bn_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10, 10, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_convtranspose3d_batchnorm3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_convtranspose3d_batchnorm3d.pt inputshape=[1,12,10,10,10]") + + # pnnx inference + import test_pnnx_fuse_convtranspose3d_batchnorm3d_pnnx + b = test_pnnx_fuse_convtranspose3d_batchnorm3d_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_view_as_complex.py 
b/tools/pnnx/tests/test_torch_view_as_complex.py new file mode 100644 index 00000000000..c2cedc537d0 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_complex.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 2) + y = torch.rand(1, 5, 9, 2) + z = torch.rand(14, 8, 5, 9, 2) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_complex.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_complex.pt inputshape=[1,3,2],[1,5,9,2],[14,8,5,9,2]") + + # pnnx inference + import test_torch_view_as_complex_pnnx + b = test_torch_view_as_complex_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file diff --git a/tools/pnnx/tests/test_torch_view_as_real.py b/tools/pnnx/tests/test_torch_view_as_real.py new file mode 100644 index 00000000000..06bbe7de9b1 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_real.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
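The two new tests around this point are complementary: torch.view_as_complex packs a trailing dimension of size 2 into a complex tensor, and torch.view_as_real, whose test body continues below, unpacks it again. A quick illustration of the relationship the tests rely on:

import torch

t = torch.rand(1, 3, 2)            # last dim holds (real, imag) pairs
c = torch.view_as_complex(t)       # shape (1, 3), dtype complex64
r = torch.view_as_real(c)          # back to shape (1, 3, 2)
assert torch.equal(r, t)
assert torch.equal(c.real, t[..., 0]) and torch.equal(c.imag, t[..., 1])

The "c64" suffixes in the view_as_real inputshape string appear to mark those pnnx inputs as complex64, matching the dtype=torch.complex64 tensors traced in the test.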
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_real(x) + y = torch.view_as_real(y) + z = torch.view_as_real(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16,dtype=torch.complex64) + y = torch.rand(1, 5, 9, 11,dtype=torch.complex64) + z = torch.rand(14, 8, 5, 9, 10,dtype=torch.complex64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_real.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_real.pt inputshape=[1,3,16]c64,[1,5,9,11]c64,[14,8,5,9,10]c64") + + # pnnx inference + import test_torch_view_as_real_pnnx + b = test_torch_view_as_real_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file
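Across the activation tests in this patch the comparison style also shifts from torch.equal on unpacked outputs to torch.allclose over a zipped tuple of outputs, with 1e-4 tolerances in most places and a looser 1e-3 for gelu. A sketch of that shared pattern; the helper name here is illustrative only and does not exist in the repo:

import torch

def outputs_match(a, b, rtol=1e-4, atol=1e-4):
    # a, b: tuples of tensors from the reference module and the pnnx-generated module
    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, rtol, atol):
            return False
    return True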