
[mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns #86005

Merged: 2 commits into llvm:main from the arm_neon_vecmat branch on Apr 3, 2024

Conversation

@KoolJBlack (Contributor) commented Mar 20, 2024

Updates the smmla unrolling patterns to handle vecmat contracts where dimM = 1. This covers both explicit vecmats of the form <1x8xi8> x <8x8xi8> --> <1x8xi32> and the implicit form with the leading dim folded: <8xi8> x <8x8xi8> --> <8xi32>.

Since smmla operates on two <2x8xi8> input vectors to produce a <2x2xi32> accumulator, half of each 2x2 accumulator tile is dummy data not pertinent to the computation: of the 2 x 2 x 8 = 32 multiply-accumulates per instruction, only the 16 feeding the valid row are useful, so vecmat runs at half throughput.
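
For reference, the folded-leading-dim form corresponds to a contract like the one in the new test added below (the i8 operands are sign-extended to i32 before the contract, as in the existing tests):

  %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
  %res = vector.contract {
      indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"],
      kind = #vector.kind<add>}
    %lhs_extsi, %rhs_extsi, %acc
    : vector<8xi32>, vector<8x8xi32> into vector<8xi32>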

@KoolJBlack force-pushed the arm_neon_vecmat branch 2 times, most recently from d39ae36 to 4259a93 on March 21, 2024 03:22
@KoolJBlack marked this pull request as ready for review March 21, 2024 03:28
@llvmbot (Member) commented Mar 21, 2024

@llvm/pr-subscribers-mlir-neon

@llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

Updates the smmla unrolling patterns to handle vecmat contracts where dimM = 1. This covers both explicit vecmats of the form <1x8xi8> x <8x8xi8> --> <1x8xi32> and the implicit form with the leading dim folded: <8xi8> x <8x8xi8> --> <8xi32>.

Since smmla operates on two <2x8xi8> input vectors to produce a <2x2xi32> accumulator, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, resulting in half throughput.


Full diff: https://github.com/llvm/llvm-project/pull/86005.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+46-27)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+122)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b17..5ec02e6db9b8b7 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
 
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling
+/// is necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -49,32 +50,28 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
       return failure();
     }
 
@@ -120,11 +117,14 @@ class LowerContractionToSMMLAPattern
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> smmlaShape{isVecmat ? 1 : 2, 2, 8};
     SmallVector<int64_t> loopOrder{0, 1, 2};
+    if (unrolledSize.size() == 2) {
+      smmlaShape = {2, 8};
+      loopOrder = {0, 1};
+    }
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +150,30 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
+      // With vecmat, tiled LHS and ACC contain only one of the 2 rows needed
+      // along dimM. Broadcast both to the full tile width.
+      if (isVecmat) {
+        auto lhsBroadcastType = VectorType::get(
+            {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+        tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+                                                        tiledLhs);
+        auto accBroadcastType = VectorType::get(
+            {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+        tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+                                                        tiledAcc);
+      }
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType = VectorType::get(
           tiledLhs.getType().cast<ShapedType>().getNumElements(),
           tiledLhs.getType().cast<ShapedType>().getElementType());
-      auto collapsedOutputType = VectorType::get(
-          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedOutputType = VectorType::get(
+          tiledAcc.getType().cast<ShapedType>().getNumElements(),
+          tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
@@ -172,6 +186,11 @@ class LowerContractionToSMMLAPattern
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
+      // With vecmat, only one row of tiled ACC is inserted into the final result.
+      if (isVecmat) {
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+      }
+
       // Insert the tiled result back into the non tiled result of the
       // contract op.
       SmallVector<int64_t> strides(
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f2..b2cc3e611ea3d9 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,125 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK:  return %[[VAL_47]] : vector<8xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME:  %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME:  %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:  %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK:  %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK:  return %[[VAL_47]] : vector<1x8xi32>
+// CHECK:  }
+
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+  return %res : vector<1x8xi32>
+}
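
As a usage sketch (an assumption, since the test file's RUN line sits outside these hunks), the new cases would be exercised by the RUN line at the top of lower-to-arm-neon.mlir, along the lines of:

  // Hypothetical RUN line; the exact pass flag is an assumption,
  // not confirmed by the diff above.
  // RUN: mlir-opt -test-lower-to-arm-neon -split-input-file %s | FileCheck %s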

@dcaballe (Contributor) left a comment


LG! A few comments!

@dcaballe (Contributor) left a comment


LGTM, thanks! Feel free to address whatever nit is remaining before landing.

@KoolJBlack merged commit c511c90 into llvm:main on Apr 3, 2024
3 of 4 checks passed
@KoolJBlack deleted the arm_neon_vecmat branch on April 3, 2024 23:26