[mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns #86005
Conversation
@llvm/pr-subscribers-mlir-neon @llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

Updates smmla unrolling patterns to handle vecmat contracts where dimM=1. This includes explicit vecmats in the form <1x8xi8> x <8x8xi8> --> <1x8xi32>, or implied with the leading dim folded: <8xi8> x <8x8xi8> --> <8xi32>.

Since the smmla operates on two <2x8xi8> input vectors to produce <2x2xi32> accumulators, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, resulting in half throughput.

Full diff: https://github.com/llvm/llvm-project/pull/86005.diff

2 Files Affected:
- mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
- mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
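For reference, a minimal example of the implied-vecmat form (leading dim folded) that the updated pattern accepts, adapted from the new test case added in this PR; the function name is illustrative:

```mlir
// Implied vecmat: the LHS carries only dimK, so the pattern infers dimM = 1.
func.func @vecmat_example(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc: vector<8xi32>) -> vector<8xi32> {
  %lhs_i32 = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
  %rhs_i32 = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
  %res = vector.contract {
      indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>}
      %lhs_i32, %rhs_i32, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
  return %res : vector<8xi32>
}
```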
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27aa27b17..5ec02e6db9b8b7 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling is
+/// necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -49,32 +50,28 @@ class LowerContractionToSMMLAPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1 ? true : false;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch
+ }
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
return failure();
}
@@ -120,11 +117,14 @@ class LowerContractionToSMMLAPattern
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape{2, 2, 8};
+ SmallVector<int64_t> smmlaShape{isVecmat ? 1 : 2, 2, 8};
SmallVector<int64_t> loopOrder{0, 1, 2};
+ if (unrolledSize.size() == 2) {
+ smmlaShape = {2, 8};
+ loopOrder = {0, 1};
+ }
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +150,30 @@ class LowerContractionToSMMLAPattern
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+ // rows along dimM. Broadcast both to the full width.
+ if (isVecmat) {
+ auto lhsBroadcastType = VectorType::get(
+ {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+ tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+ tiledLhs);
+ auto accBroadcastType = VectorType::get(
+ {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+ tiledAcc);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
tiledLhs.getType().cast<ShapedType>().getNumElements(),
tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType = VectorType::get(
+ tiledAcc.getType().cast<ShapedType>().getNumElements(),
+ tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
@@ -172,6 +186,11 @@ class LowerContractionToSMMLAPattern
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ // With vecmat, only one row of tiled ACC can be inserted into the final result
+ if (isVecmat) {
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+ }
+
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e2be87453bf6f2..b2cc3e611ea3d9 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -134,3 +134,125 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
return %res : vector<4x4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: return %[[VAL_47]] : vector<8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+ %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+ return %res : vector<8xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: return %[[VAL_47]] : vector<1x8xi32>
+// CHECK: }
+
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+ %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+ return %res : vector<1x8xi32>
+}
LG! A few comments!
LGTM, thanks! Feel free to address whatever nit is remaining before landing.
Updates smmla unrolling patterns to handle vecmat contracts where dimM=1. This includes explicit vecmats in the form <1x8xi8> x <8x8xi8> --> <1x8xi32>, or implied with the leading dim folded: <8xi8> x <8x8xi8> --> <8xi32>.

Since the smmla operates on two <2x8xi8> input vectors to produce <2x2xi32> accumulators, half of each 2x2 accumulator tile is dummy data not pertinent to the computation, resulting in half throughput.
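To make the half-throughput point concrete, this is roughly what a single unrolled vecmat tile looks like after the rewrite, following the CHECK lines of the new test; the wrapper function and value names are illustrative:

```mlir
// One unrolled tile of the vecmat lowering: broadcast the single LHS row and
// the 2-wide ACC slice to the full 2-row smmla shapes, run the intrinsic, and
// keep only row 0 of the 2x2 result (row 1 is the dummy half of the tile).
func.func @smmla_vecmat_tile(%lhs: vector<8xi8>, %rhs_tile: vector<2x8xi8>,
                             %acc_slice: vector<2xi32>) -> vector<2xi32> {
  %lhs_2x8 = vector.broadcast %lhs : vector<8xi8> to vector<2x8xi8>
  %acc_2x2 = vector.broadcast %acc_slice : vector<2xi32> to vector<2x2xi32>
  // Collapse tiled operands to the 1-D vectors required by the smmla intrinsic.
  %lhs_flat = vector.shape_cast %lhs_2x8 : vector<2x8xi8> to vector<16xi8>
  %rhs_flat = vector.shape_cast %rhs_tile : vector<2x8xi8> to vector<16xi8>
  %acc_flat = vector.shape_cast %acc_2x2 : vector<2x2xi32> to vector<4xi32>
  %res_flat = arm_neon.intr.smmla %acc_flat, %lhs_flat, %rhs_flat : vector<16xi8> to vector<4xi32>
  %res_2x2 = vector.shape_cast %res_flat : vector<4xi32> to vector<2x2xi32>
  %row0 = vector.extract %res_2x2[0] : vector<2xi32> from vector<2x2xi32>
  return %row0 : vector<2xi32>
}
```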