diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b170..13740225749e46c 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
 
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling is
+/// necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -49,32 +50,35 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1 ? true : false;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract. All iterators except inner-most
+    // dimension must be parallel.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
+      return failure();
+    }
+    if (llvm::any_of(ArrayRef(iteratorTypes).drop_back(1),
+                     [](vector::IteratorType iteratorType) {
+                       return iteratorType != vector::IteratorType::parallel;
+                     })) {
       return failure();
     }
 
@@ -120,11 +124,14 @@ class LowerContractionToSMMLAPattern
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape{2, 2, 8};
-    SmallVector<int64_t> loopOrder{0, 1, 2};
+    SmallVector<int64_t> smmlaShape{2, 8};
+    SmallVector<int64_t> loopOrder{0, 1};
+    if (unrolledSize.size() == 3) {
+      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+      loopOrder.push_back(2);
+    }
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +157,40 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
+      auto inputElementType =
+          tiledLhs.getType().cast<ShapedType>().getElementType();
+      auto accElementType =
+          tiledAcc.getType().cast<ShapedType>().getElementType();
+      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
+      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+      // rows along dimM. Expand their shapes to match the smmla op.
+      if (isVecmat) {
+        auto expandForSMMLA = [&](Value tiledOperand,
+                                  VectorType expandedTypeType) {
+          auto emptyOperand = rewriter.create<arith::ConstantOp>(
+              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+          SmallVector<int64_t> offsets(
+              emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+          SmallVector<int64_t> strides(
+              tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+              loc, tiledOperand, emptyOperand, offsets, strides);
+        };
+        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
+        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+      }
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType = VectorType::get(
-          tiledLhs.getType().cast<ShapedType>().getNumElements(),
-          tiledLhs.getType().cast<ShapedType>().getElementType());
-      auto collapsedOutputType = VectorType::get(
-          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedInputType =
+          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedOutputType =
+          VectorType::get(outputExpandedType.getNumElements(), accElementType);
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
@@ -172,6 +203,11 @@ class LowerContractionToSMMLAPattern
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
+      // With vecmat, only one row of tiled ACC can be inserted into the final result
+      if (isVecmat) {
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+      }
+
       // Insert the tiled result back into the non tiled result of the
       // contract op.
       SmallVector<int64_t> strides(
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f25..46c4026d13b6603 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,127 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1]} : vector<8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: return %[[VAL_49]] : vector<8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_13:.*]] = arm_neon.intr.smmla %[[VAL_12]], %[[VAL_10]], %[[VAL_11]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_14:.*]] = vector.shape_cast %[[VAL_13]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_14]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_16:.*]] = vector.insert_strided_slice %[[VAL_15]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_17:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_22:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_24:.*]] = arm_neon.intr.smmla %[[VAL_23]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_24]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_25]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_16]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_31:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_30]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_35:.*]] = arm_neon.intr.smmla %[[VAL_34]], %[[VAL_32]], %[[VAL_33]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_35]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_36]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_38:.*]] = vector.insert_strided_slice %[[VAL_37]], %[[VAL_27]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_39:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.insert_strided_slice %[[VAL_0]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8>
+// CHECK: %[[VAL_42:.*]] = vector.insert_strided_slice %[[VAL_40]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_41]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_44:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_42]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_46:.*]] = arm_neon.intr.smmla %[[VAL_45]], %[[VAL_43]], %[[VAL_44]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_47:.*]] = vector.shape_cast %[[VAL_46]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_48:.*]] = vector.extract %[[VAL_47]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: return %[[VAL_49]] : vector<1x8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+  %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+  return %res : vector<1x8xi32>
+}