diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c68851ce6bb24..b70309360a6fa 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -62,6 +62,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ) set(mlas_platform_preprocess_srcs @@ -311,6 +312,7 @@ else() ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S @@ -320,6 +322,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ) if (NOT APPLE) set(mlas_platform_srcs diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000000..e716bb3505819 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. 
+// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. 
+// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 +.endif + mov v8.4S[2], v9.4S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.4S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.4S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().4S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().4S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 +.endif + + mov v8.2D[1], v9.2D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.2D[0] + mov x28, v8.2D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().2D[0] + mov x28, v\Vec1Reg\().2D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.4S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.4S[2], w27 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.2D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.4S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.2D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.4S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.2D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.4S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.2D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.4S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro 
OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.4S[0], w28 + mov v8.4S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.4S[0] + mov w28, v8.4S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().4S[0] + mov w28, v\Vec3Reg\().4S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 + mov v11.2D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.2D[1], v11.2D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.2D[0] + mov x28, v10.2D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().2D[0] + mov x28, v\Vec3Reg\().2D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.4S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.4S[2], w27 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 + + mov v11.2D[0],x27 + mov v11.2D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 
v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.2D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.4S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.2D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.4S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.2D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.4S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.2D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.4S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.2D[0],x27 + mov v9.2D[1],x27 + + mov v11.2D[0],x27 + mov v11.2D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. 
+// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. 
+ +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. 
+// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 9411882efc82e..75937f948a6d3 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -830,6 +830,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index c05d9499e1dfe..1882ae328ccaf 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -485,6 +485,7 @@ Return Value: if (HasI8MMInstruction && ort_use_int8mm()) { this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; } #endif // MLAS_TARGET_ARM64 diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000000..0a9b91a9662d7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,992 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. 
+// + +extern "C" { + + size_t + MLASCALL + MlasGemmS8S8KernelSmmlaZero( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB + ); + + size_t + MLASCALL + MlasGemmS8S8KernelSmmlaAdd( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB + ); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA +{ + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 24, 128, 256 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 24, 128, 384 }; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + while (CountM >= 8) { + + const int8_t* a0 = reinterpret_cast(A);; + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = 
vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01= vcombine_s64(v0, v1); + int64x2_t z23= vcombine_s64(v2, v3); + int64x2_t z45= vcombine_s64(v4, v5); + int64x2_t z67= vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01= vcombine_s64(v0, v1); + int64x2_t z23= vcombine_s64(v2, v3); + int64x2_t z45= 
vcombine_s64(v4, v5); + int64x2_t z67= vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 4) { + + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01= vcombine_s64(v0, v1); + int64x2_t z23= vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes with zero padding. 
+ // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01= vcombine_s64(v0, v1); + int64x2_t z23= vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. 
+ // + + if (CountM >= 2) { + + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + + } + + while (k >= 8) { + + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01= vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + + } + + if (k > 0) { + + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01= vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. 
+ // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla( + int8_t* D, + int8x8_t BytesRow[8], + int32x4_t ColumnSums[2] + ) +{ + + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); + +} + + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. 
+ // + // + while (CountN >= 8) { + + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + //Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? 
bcopy6 : &PaddedMatrixBData[56];
+                bcopy7 = &PaddedMatrixBData[56];
+
+                k = 0;
+            }
+
+            int8_t* padded = PaddedMatrixBData;
+            int8_t* padded_end = padded + CountN;
+            do {
+                padded[0] = *bcopy0++;
+                padded[8] = *bcopy1++;
+                padded[16] = *bcopy2++;
+                padded[24] = *bcopy3++;
+                padded[32] = *bcopy4++;
+                padded[40] = *bcopy5++;
+                padded[48] = *bcopy6++;
+                padded[56] = *bcopy7++;
+
+            } while (++padded < padded_end);
+
+            BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]);
+            BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]);
+            BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]);
+            BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]);
+            BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]);
+            BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]);
+            BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]);
+            BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]);
+
+            MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums);
+
+            D += 64;
+        }
+
+        vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]);
+        vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]);
+    }
+}
+
+template<>
+MLAS_FORCEINLINE
+size_t
+MlasGemmQuantKernel<MLAS_GEMM_S8S8_KERNEL_SMMLA>(
+    const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A,
+    const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B,
+    int32_t* C,
+    size_t PackedCountK,
+    size_t CountM,
+    size_t CountN,
+    size_t ldc,
+    const int32_t* RowSumBuffer,
+    const int32_t* ColumnSumBuffer,
+    const int32_t* ZeroPointB,
+    bool ZeroMode
+    )
+{
+    size_t RowsHandled;
+
+    if (ZeroMode) {
+        RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc,
+                                                  RowSumBuffer, ColumnSumBuffer, ZeroPointB);
+    } else {
+        RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc,
+                                                 RowSumBuffer, ColumnSumBuffer, ZeroPointB);
+    }
+
+    return RowsHandled;
+}
+
+const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = {
+    MlasGemmQuantOperation<MLAS_GEMM_S8S8_KERNEL_SMMLA>,
+    MlasGemmQuantPackedOperation<MLAS_GEMM_S8S8_KERNEL_SMMLA>,
+    MlasGemmQuantCopyPackB<MLAS_GEMM_S8S8_KERNEL_SMMLA>,
+    MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK,
+    MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K,
+    8
+};
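
Reviewer note (not part of the patch): for readers unfamiliar with the i8mm extension, the sketch below is a minimal scalar model of what a single `smmla` instruction computes under the ARMv8.6 FEAT_I8MM semantics: the first source register is read as a 2x8 int8 matrix, the second as an 8x2 int8 matrix stored as two 8-byte columns, and a 2x2 int32 tile is accumulated into the destination. The function name and array layout are illustrative assumptions, not code from this PR. It motivates the packing done above: `MlasGemmQuantCopyPackA` interleaves pairs of A rows in 8-byte groups, `MlasGemmQuantCopyPackB` interleaves pairs of B columns, and the `OutputRow*Element` macros use `uzp1`/`uzp2` to de-interleave the resulting 2x2 tiles back into row-major C.

```cpp
#include <cstdint>

// Scalar reference for one smmla: Acc (2x2 int32) += ARows (2x8 int8) * BCols^T (8x2 int8).
// Each of the three arguments corresponds to one 128-bit vector register in the kernel.
static void SmmlaReference(int32_t Acc[2][2],
                           const int8_t ARows[2][8],   // two packed rows of A (16 bytes)
                           const int8_t BCols[2][8])   // two packed columns of B (16 bytes)
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            int32_t dot = 0;
            for (int k = 0; k < 8; k++) {
                dot += int32_t(ARows[i][k]) * int32_t(BCols[j][k]);
            }
            Acc[i][j] += dot;  // accumulate the 2x2 tile, as smmla does
        }
    }
}
```

Because the `InitRowAccumulators`/`InitBlockAccumulators` macros preload each tile with `ZeroPointB * RowSum + ColumnSum`, the signed 8-bit dot products above are the only arithmetic the inner loop has to perform; the zero-point corrections are already folded into the accumulators before the first `smmla` executes.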