From 8ea260a0931df774e49fab76fa030a2ca4998f54 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Tue, 31 Oct 2023 13:08:55 +0000 Subject: [PATCH] [mlir][ArmSME] Add mask operand to load_tile_slice (#70655) --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 27 +++-- .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 19 ++- .../Transforms/LegalizeForLLVMExport.cpp | 37 +++--- .../ArmSMEToSCF/arm-sme-to-scf.mlir | 15 ++- mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 76 ++++++------ mlir/test/Dialect/ArmSME/invalid.mlir | 13 ++ mlir/test/Dialect/ArmSME/roundtrip.mlir | 114 +++++++++--------- 7 files changed, 174 insertions(+), 127 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 2f6e52ff2badbe..37a2257a0015ce 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -390,7 +390,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { } def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ - AllTypesMatch<["tile", "result"]> + AllTypesMatch<["tile", "result"]>, + TypesMatchWith< + "mask has i1 element type and is a slice of the result", + "result", "mask", + "VectorType(" + "VectorType::Builder(" + "::llvm::cast($_self)" + ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))" + ")">, ]> { let summary = "Tile slice load and update operation"; let description = [{ @@ -406,23 +414,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the result. + An SSA value `mask` specifies to mask out elements read from the MemRef. + The `mask` type is an `i1` vector with a shape that matches how elements + are read from the MemRef. + Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index. 
```mlir - %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref, vector<[16]xi1>, vector<[16]x[16]xi8> ``` Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index. ```mlir - %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> ``` Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index. ```mlir - %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> ``` }]; let arguments = (ins - Arg:$base, + Arg:$base, SVEPredicate:$mask, SMETile:$tile, Variadic:$indices, Index:$tile_slice_index, ArmSME_TileSliceLayoutAttr:$layout ); @@ -438,8 +450,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ }]; let assemblyFormat = [{ - $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)? - attr-dict `:` type($base) `,` type($result) + $base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index + (`layout` `` $layout^)? 
attr-dict `:` type($base) `,` type($mask) `,` + type($result) }]; } diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 0ec51b7430c021..50cc818f1ffc09 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, /// /// AFTER: /// ```mlir +/// %ptrue_s = arith.constant dense : vector<[4]xi1> /// %tile_id = arm_sme.get_tile_id : i32 /// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> /// %vscale = vector.vscale @@ -69,7 +70,8 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, /// %svl_s = arith.muli %min_svl_s, %vscale : index /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { /// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx], -/// %tile, %tile_slice_idx : memref, vector<[4]x[4]xi32> +/// %ptrue_s, %tile, %tile_slice_idx +/// : memref, vector<[4]xi1>, vector<[4]x[4]xi32> /// } /// ``` struct TileLoadOpConversion : public OpRewritePattern { @@ -77,6 +79,11 @@ struct TileLoadOpConversion : public OpRewritePattern { LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, PatternRewriter &rewriter) const override { + if (tileLoadOp.getMask()) + // TODO: add masked patterns. + return rewriter.notifyMatchFailure( + tileLoadOp, "op has mask, needs masked pattern(s)"); + OpBuilder::InsertionGuard g(rewriter); auto loc = tileLoadOp.getLoc(); auto tileType = tileLoadOp.getVectorType(); @@ -109,6 +116,12 @@ struct TileLoadOpConversion : public OpRewritePattern { rewriter.setInsertionPointToStart(forOp.getBody()); + // Create an 'all true' predicate for the tile slice. 
+ auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + auto allTruePredicate = rewriter.create<arith::ConstantOp>( + loc, DenseElementsAttr::get(predicateType, true)); + // Create 'arm_sme.load_tile_slice' to load tile slice from memory into // tile. SmallVector<Value> memrefIndices; @@ -117,8 +130,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { tileLoadOp.getMemRefType().getRank(), tileSliceIndex, numTileSlices, memrefIndices, loc, rewriter); rewriter.create<arm_sme::LoadTileSliceOp>( - loc, tileType, tileLoadOp.getBase(), tile, memrefIndices, - tileSliceIndex, tileLoadOp.getLayout()); + loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile, + memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 105f2de207a084..7dd04e25075c8d 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. 
- auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI1Type(), - rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); - auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), - /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); + auto maskOp = loadTileSliceOp.getMask(); auto tileI32 = castTileIDToI32(tile, loc, rewriter); arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout(); @@ -195,24 +190,24 @@ struct LoadTileSliceToArmSMELowering default: llvm_unreachable("unexpected element type!"); case 8: - rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( - loc, allActiveMask, ptr, tileI32, tileSliceI32); + rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr, + tileI32, tileSliceI32); break; case 16: - rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( - loc, allActiveMask, ptr, tileI32, tileSliceI32); + rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr, + tileI32, tileSliceI32); break; case 32: - rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( - loc, allActiveMask, ptr, tileI32, tileSliceI32); + rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr, + tileI32, tileSliceI32); break; case 64: - rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( - loc, allActiveMask, ptr, tileI32, tileSliceI32); + rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr, + tileI32, tileSliceI32); break; case 128: - rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( - loc, allActiveMask, ptr, tileI32, tileSliceI32); + rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr, + tileI32, tileSliceI32); break; } } else { @@ -220,23 +215,23 @@ struct LoadTileSliceToArmSMELowering default: llvm_unreachable("unexpected element type!"); case 8: - rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr, + rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr, tileI32, tileSliceI32); break; case 16: - rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr, + rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr, tileI32, tileSliceI32); break; case 32: - rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr, + rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr, tileI32, tileSliceI32); break; case 64: - rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr, + rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr, tileI32, tileSliceI32); break; case 128: - rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr, + rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr, tileI32, 
tileSliceI32); break; } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 4b3020970d6ccc..3fb320c0d219e6 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s +//===----------------------------------------------------------------------===// +// arm_sme.tile_load +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func.func @arm_sme_tile_load_hor( // CHECK-SAME: %[[SRC:.*]]: memref) { // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 @@ -10,8 +14,9 @@ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense : vector<[4]xi1> // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index -// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref, vector<[4]x[4]xi32> +// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> func.func @arm_sme_tile_load_hor(%src : memref) { %c0 = arith.constant 0 : index %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[4]x[4]xi32> @@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref) { return } +//===----------------------------------------------------------------------===// +// arm_sme.tile_store +//===----------------------------------------------------------------------===// + // ----- // CHECK-LABEL: func.func @arm_sme_tile_store_hor( @@ -57,6 +66,10 @@ func.func 
@arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref) diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 9074f0a7ee655c..30ddb3c4686018 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -8,9 +8,9 @@ // CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i8( // CHECK-SAME: %[[SRC:.*]]: memref, +// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>, // CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>, // CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index) { -// CHECK: %[[PTRUE_B:.*]] = arith.constant dense : vector<[16]xi1> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 @@ -21,12 +21,12 @@ // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32 // CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 -// CHECK: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK: "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () // CHECK: return // CHECK: } -func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[16]x[16]xi8> + %tile_update = 
arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } @@ -34,9 +34,9 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %tile : vector< // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_i16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } @@ -44,9 +44,9 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_i32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } @@ -54,9 +54,9 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_i64(%src : memref, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { 
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } @@ -64,9 +64,9 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128 // CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_i128(%src : memref, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } @@ -74,9 +74,9 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref, %tile : vec // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_f16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[8]x[8]xf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } @@ -84,9 +84,9 @@ func.func 
@arm_sme_load_tile_slice_hor_f16(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } @@ -94,9 +94,9 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref, %tile : vec // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_f32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } @@ -104,9 +104,9 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { %c0 
= arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } @@ -114,9 +114,9 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } @@ -124,9 +124,9 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector< // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } @@ -134,9 +134,9 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32 // CHECK: 
"arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } @@ -144,9 +144,9 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } @@ -154,9 +154,9 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, 
%tile_slice_index layout : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } @@ -164,9 +164,9 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vec // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } @@ -174,9 +174,9 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } @@ -184,9 +184,9 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vec // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () 
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } @@ -194,9 +194,9 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vecto // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () -func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { +func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index dba8b1937936e2..1d6386bbf3828f 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -151,6 +151,19 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64 return } +//===----------------------------------------------------------------------===// +// arm_sme.load_tile_slice +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : 
index) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that mask has i1 element type and is a slice of the result}} + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.outerproduct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 90b05c54c58d93..6d0aa48015c145 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -638,173 +638,173 @@ func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memre // ----- -func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[16]x[16]xi8> +func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } // ----- -func.func @arm_sme_load_tile_slice_hor_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]x[8]xi16> +func.func @arm_sme_load_tile_slice_hor_i16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice 
%{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]xi1>, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } // ----- -func.func @arm_sme_load_tile_slice_hor_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[4]x[4]xi32> +func.func @arm_sme_load_tile_slice_hor_i32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[4]xi1>, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } // ----- -func.func @arm_sme_load_tile_slice_hor_i64(%src : memref, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[2]x[2]xi64> +func.func @arm_sme_load_tile_slice_hor_i64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[2]xi1>, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } // ----- -func.func @arm_sme_load_tile_slice_hor_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : 
index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[1]x[1]xi128> +func.func @arm_sme_load_tile_slice_hor_i128(%src : memref, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[1]xi1>, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } // ----- -func.func @arm_sme_load_tile_slice_hor_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]x[8]xf16> +func.func @arm_sme_load_tile_slice_hor_f16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]xi1>, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[8]x[8]xf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } // ----- -func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]x[8]xbf16> +func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, 
%tile_slice_index : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } // ----- -func.func @arm_sme_load_tile_slice_hor_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[4]x[4]xf32> +func.func @arm_sme_load_tile_slice_hor_f32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[4]xi1>, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } // ----- -func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[2]x[2]xf64> +func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[2]xi1>, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } // ----- -func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[16]x[16]xi8> +func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %mask : vector<[16]xi1>, 
%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return } // ----- -func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xi16> +func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xi16> return } // ----- -func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]x[4]xi32> +func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]xi1>, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[4]xi1>, vector<[4]x[4]xi32> return } // ----- -func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : 
vector<[2]x[2]xi64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]x[2]xi64> +func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[2]xi1>, vector<[2]x[2]xi64> return } // ----- -func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[1]x[1]xi128> +func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[1]xi1>, vector<[1]x[1]xi128> return } // ----- -func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xf16> +func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xf16> + 
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xf16> return } // ----- -func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xbf16> +func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[8]xi1>, vector<[8]x[8]xbf16> return } // ----- -func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]x[4]xf32> +func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[4]xi1>, vector<[4]x[4]xf32> return } // ----- -func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]x[2]xf64> +func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { + // CHECK: 
arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[2]xi1>, vector<[2]x[2]xf64> return } // ----- /// Layout is optional and horizontal is the default, verify it's still parsed. -func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[16]x[16]xi8> +func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { + // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[16]xi1>, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout : memref, vector<[16]xi1>, vector<[16]x[16]xi8> return }