diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 049c9759d70bf4..dab54b63d8d22b 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [ def ArmSME_TileSliceLayoutAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; + let defaultValue = "TileSliceLayout::Horizontal"; } //===----------------------------------------------------------------------===// @@ -248,19 +249,18 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory. ```mlir - %tile = arm_sme.tile_load %base[%c0, %c0], : memref, vector<[4]x[4]xf32> + %tile = arm_sme.tile_load %base[%c0, %c0] layout : memref, vector<[4]x[4]xf32> ``` Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory. ```mlir - %tile = arm_sme.tile_load %base[%c0, %c0], : memref, vector<[1]x[1]xi128> + %tile = arm_sme.tile_load %base[%c0, %c0] layout : memref, vector<[1]x[1]xi128> ``` }]; let arguments = (ins Arg:$base, Variadic:$indices, - DefaultValuedAttr:$layout + ArmSME_TileSliceLayoutAttr:$layout ); let results = (outs SMETile:$result); @@ -274,7 +274,7 @@ def TileLoadOp : ArmSME_Op<"tile_load"> { }]; let assemblyFormat = - "$base `[` $indices `]` (`,` $layout^)? attr-dict " + "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict " "`:` type($base) `,` type($result)"; } @@ -296,19 +296,17 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory. ```mlir - arm_sme.tile_store %tile, %base[%c0, %c0], : vector<[4]x[4]xf32>, memref + arm_sme.tile_store %tile, %base[%c0, %c0] layout : vector<[4]x[4]xf32>, memref ``` Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory. ```mlir - arm_sme.tile_store %tile, %base[%c0, %c0], : vector<[1]x[1]xi128>, memref + arm_sme.tile_store %tile, %base[%c0, %c0] layout : vector<[1]x[1]xi128>, memref ``` }]; let arguments = (ins SMETile:$valueToStore, Arg:$base, - Variadic:$indices, - DefaultValuedAttr:$layout + Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout ); let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -320,7 +318,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { }]; let assemblyFormat = - "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict " + "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict " "`:` type($base) `,` type($valueToStore)"; } @@ -348,19 +346,18 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index. ```mlir - %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> ``` Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index. ```mlir - %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout : memref, vector<[1]x[1]xi128> ``` }]; let arguments = (ins Arg:$base, SMETile:$tile, Variadic:$indices, Index:$tile_slice_index, - DefaultValuedAttr:$layout + ArmSME_TileSliceLayoutAttr:$layout ); let results = (outs SMETile:$result); @@ -374,7 +371,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ }]; let assemblyFormat = [{ - $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)? + $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)? attr-dict `:` type($base) `,` type($result) }]; } @@ -401,19 +398,17 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], : vector<[4]x[4]xf32>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[4]x[4]xf32>, memref ``` Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory. ```mlir - arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], : vector<[1]x[1]xi128>, memref + arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout : vector<[1]x[1]xi128>, memref ``` }]; let arguments = (ins SMETile:$tile, Index:$tile_slice_index, Arg:$base, - Variadic:$indices, - DefaultValuedAttr:$layout + Variadic:$indices, ArmSME_TileSliceLayoutAttr:$layout ); let extraClassDeclaration = [{ MemRefType getMemRefType() { @@ -425,7 +420,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { }]; let assemblyFormat = [{ - $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)? + $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict `:` type($base) `,` type($tile) }]; } diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 881cc8575fb482..0ec51b7430c021 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern { /// /// BEFORE: /// ```mlir -/// arm_sme.tile_store %tile, %dest[%c0, %c0], +/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout /// : memref, vector<[4]x[4]xi32 /// ``` /// @@ -147,7 +147,7 @@ struct TileLoadOpConversion : public OpRewritePattern { /// %svl_s = arith.muli %min_svl_s, %vscale : index /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx], -/// : memref, vector<[4]x[4]xi32> +/// layout : memref, vector<[4]x[4]xi32> /// } /// ``` struct TileStoreOpConversion : public OpRewritePattern { diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index cbc5e468c72937..d06eb4f5b01c95 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -67,7 +67,7 @@ namespace { /// /// is converted to: /// -/// arm_sme.tile_load ... +/// arm_sme.tile_load ... layout struct TransferReadPermutationToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -368,8 +368,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern { /// %alloca = memref.alloca(%svl_s, %svl_s) : memref /// %arm_sme.tile_store %src, , %alloca[%c0, %c0] /// : memref, vector<[4]x[4]xi32> -/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], -/// : memref, vector<[4]x[4]xi32> +/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] +/// layout : memref, vector<[4]x[4]xi32> /// /// NOTE: Tranposing via memory is obviously expensive, the current intention /// is to avoid the transpose if possible, this is therefore intended as a diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 09f148bcd42f59..4b3020970d6ccc 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -21,10 +21,10 @@ func.func @arm_sme_tile_load_hor(%src : memref) { // ----- // CHECK-LABEL: @arm_sme_tile_load_ver -// CHECK: arm_sme.load_tile_slice {{.*}} +// CHECK: arm_sme.load_tile_slice {{.*}} layout func.func @arm_sme_tile_load_ver(%src : memref) { %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[4]x[4]xi32> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[4]x[4]xi32> return } @@ -50,10 +50,10 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref +// CHECK: arm_sme.store_tile_slice {{.*}} layout func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref) { %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[4]x[4]xi32> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[4]x[4]xi32> return } diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 4c16e5c488a74c..07485b3ee8ddf8 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -116,7 +116,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> return } @@ -126,7 +126,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector< // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xi16> return } @@ -136,7 +136,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xi32> return } @@ -146,7 +146,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xi64> return } @@ -156,7 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[1]x[1]xi128> return } @@ -166,7 +166,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vec // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xf16> return } @@ -176,7 +176,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xbf16> return } @@ -186,7 +186,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vec // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> return } @@ -196,7 +196,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vecto // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xf64> return } @@ -316,7 +316,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> return } @@ -326,7 +326,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xi16> return } @@ -336,7 +336,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xi32> return } @@ -346,7 +346,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xi64> return } @@ -356,7 +356,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[1]x[1]xi128> return } @@ -366,7 +366,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xf16> return } @@ -376,7 +376,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xbf16> return } @@ -386,7 +386,7 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xf32> return } @@ -396,7 +396,7 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xf64> return } diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index f6d19359b8e3af..427154158e797f 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -358,81 +358,81 @@ func.func @arm_sme_tile_load_hor_f64(%src : memref) { // ----- func.func @arm_sme_tile_load_ver_i8(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[16]x[16]xi8> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[16]x[16]xi8> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[16]x[16]xi8> return } // ----- func.func @arm_sme_tile_load_ver_i16(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xi16> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[8]x[8]xi16> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[8]x[8]xi16> return } // ----- func.func @arm_sme_tile_load_ver_i32(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xi32> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[4]x[4]xi32> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[4]x[4]xi32> return } // ----- func.func @arm_sme_tile_load_ver_i64(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xi64> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[2]x[2]xi64> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[2]x[2]xi64> return } // ----- func.func @arm_sme_tile_load_ver_i128(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[1]x[1]xi128> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[1]x[1]xi128> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[1]x[1]xi128> return } // ----- func.func @arm_sme_tile_load_ver_f16(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xf16> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[8]x[8]xf16> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[8]x[8]xf16> return } // ----- func.func @arm_sme_tile_load_ver_bf16(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xbf16> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[8]x[8]xbf16> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[8]x[8]xbf16> return } // ----- func.func @arm_sme_tile_load_ver_f32(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xf32> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[4]x[4]xf32> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[4]x[4]xf32> return } // ----- func.func @arm_sme_tile_load_ver_f64(%src : memref) { - // CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xf64> + // CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[2]x[2]xf64> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[2]x[2]xf64> return } @@ -442,7 +442,7 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref) { func.func @arm_sme_tile_load_explicit_hor(%src : memref) { // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile = arm_sme.tile_load %src[%c0, %c0], : memref, vector<[16]x[16]xi8> + %tile = arm_sme.tile_load %src[%c0, %c0] layout : memref, vector<[16]x[16]xi8> return } @@ -534,81 +534,81 @@ func.func @arm_sme_tile_store_hor_f64(%tile : vector<[2]x[2]xf64>, %dest : memre // ----- func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[16]x[16]xi8> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[16]x[16]xi8> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[16]x[16]xi8> return } // ----- func.func @arm_sme_tile_store_ver_i16(%tile : vector<[8]x[8]xi16>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[8]x[8]xi16> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[8]x[8]xi16> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[8]x[8]xi16> return } // ----- func.func @arm_sme_tile_store_ver_i32(%tile : vector<[4]x[4]xi32>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[4]x[4]xi32> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[4]x[4]xi32> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[4]x[4]xi32> return } // ----- func.func @arm_sme_tile_store_ver_i64(%tile : vector<[2]x[2]xi64>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[2]x[2]xi64> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[2]x[2]xi64> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[2]x[2]xi64> return } // ----- func.func @arm_sme_tile_store_ver_i128(%tile : vector<[1]x[1]xi128>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[1]x[1]xi128> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[1]x[1]xi128> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[1]x[1]xi128> return } // ----- func.func @arm_sme_tile_store_ver_f16(%tile : vector<[8]x[8]xf16>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[8]x[8]xf16> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[8]x[8]xf16> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[8]x[8]xf16> return } // ----- func.func @arm_sme_tile_store_ver_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[8]x[8]xbf16> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[8]x[8]xbf16> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[8]x[8]xbf16> return } // ----- func.func @arm_sme_tile_store_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[4]x[4]xf32> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[4]x[4]xf32> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[4]x[4]xf32> return } // ----- func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memref) { - // CHECK: arm_sme.tile_store {{.*}}, : memref, vector<[2]x[2]xf64> + // CHECK: arm_sme.tile_store {{.*}} layout : memref, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[2]x[2]xf64> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[2]x[2]xf64> return } @@ -618,7 +618,7 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref) { // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.tile_store %tile, %dest[%c0, %c0], : memref, vector<[16]x[16]xi8> + arm_sme.tile_store %tile, %dest[%c0, %c0] layout : memref, vector<[16]x[16]xi8> return } @@ -710,81 +710,81 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref, %tile : vecto // ----- func.func @arm_sme_load_tile_slice_ver_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[16]x[16]xi8> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> return } // ----- func.func @arm_sme_load_tile_slice_ver_i16(%src : memref, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[8]x[8]xi16> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xi16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xi16> return } // ----- func.func @arm_sme_load_tile_slice_ver_i32(%src : memref, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[4]x[4]xi32> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[4]x[4]xi32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xi32> return } // ----- func.func @arm_sme_load_tile_slice_ver_i64(%src : memref, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[2]x[2]xi64> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[2]x[2]xi64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xi64> return } // ----- func.func @arm_sme_load_tile_slice_ver_i128(%src : memref, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[1]x[1]xi128> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[1]x[1]xi128> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[1]x[1]xi128> return } // ----- func.func @arm_sme_load_tile_slice_ver_f16(%src : memref, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[8]x[8]xf16> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xf16> return } // ----- func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[8]x[8]xbf16> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[8]x[8]xbf16> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[8]x[8]xbf16> return } // ----- func.func @arm_sme_load_tile_slice_ver_f32(%src : memref, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[4]x[4]xf32> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[4]x[4]xf32> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[4]x[4]xf32> return } // ----- func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) { - // CHECK: arm_sme.load_tile_slice {{.*}}, : memref, vector<[2]x[2]xf64> + // CHECK: arm_sme.load_tile_slice {{.*}} layout : memref, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[2]x[2]xf64> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[2]x[2]xf64> return } @@ -794,7 +794,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref, %tile : vecto func.func @arm_sme_load_tile_slice_hor_i8(%src : memref, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, : memref, vector<[16]x[16]xi8> + %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout : memref, vector<[16]x[16]xi8> return } @@ -886,81 +886,81 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s // ----- func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[16]x[16]xi8> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> return } // ----- func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[8]x[8]xi16> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xi16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xi16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xi16> return } // ----- func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[4]x[4]xi32> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]x[4]xi32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[4]x[4]xi32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xi32> return } // ----- func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[2]x[2]xi64> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]x[2]xi64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[2]x[2]xi64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xi64> return } // ----- func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[1]x[1]xi128> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[1]x[1]xi128> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[1]x[1]xi128> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[1]x[1]xi128> return } // ----- func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[8]x[8]xf16> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xf16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xf16> return } // ----- func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[8]x[8]xbf16> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[8]x[8]xbf16> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[8]x[8]xbf16> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[8]x[8]xbf16> return } // ----- func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[4]x[4]xf32> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[4]x[4]xf32> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[4]x[4]xf32> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[4]x[4]xf32> return } // ----- func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref) -> () { - // CHECK: arm_sme.store_tile_slice {{.*}}, : memref, vector<[2]x[2]xf64> + // CHECK: arm_sme.store_tile_slice {{.*}} layout : memref, vector<[2]x[2]xf64> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[2]x[2]xf64> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[2]x[2]xf64> return } @@ -970,7 +970,7 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref) -> () { // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index - arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], : memref, vector<[16]x[16]xi8> + arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout : memref, vector<[16]x[16]xi8> return } diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index b2c8fd8e01ac7e..455b47a83e28f4 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -5,7 +5,7 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @transfer_read_2d_transpose_i8 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[16]x[16]xi8> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[16]x[16]xi8> func.func @transfer_read_2d_transpose_i8(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i8 @@ -17,7 +17,7 @@ func.func @transfer_read_2d_transpose_i8(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_i16 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xi16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xi16> func.func @transfer_read_2d_transpose_i16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i16 @@ -29,7 +29,7 @@ func.func @transfer_read_2d_transpose_i16(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_i32 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xi32> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xi32> func.func @transfer_read_2d_transpose_i32(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i32 @@ -41,7 +41,7 @@ func.func @transfer_read_2d_transpose_i32(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_i64 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xi64> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xi64> func.func @transfer_read_2d_transpose_i64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i64 @@ -53,7 +53,7 @@ func.func @transfer_read_2d_transpose_i64(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_i128 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[1]x[1]xi128> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[1]x[1]xi128> func.func @transfer_read_2d_transpose_i128(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i128 @@ -65,7 +65,7 @@ func.func @transfer_read_2d_transpose_i128(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_f16 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xf16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xf16> func.func @transfer_read_2d_transpose_f16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f16 @@ -77,7 +77,7 @@ func.func @transfer_read_2d_transpose_f16(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_bf16 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xbf16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xbf16> func.func @transfer_read_2d_transpose_bf16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : bf16 @@ -89,7 +89,7 @@ func.func @transfer_read_2d_transpose_bf16(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_f32 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xf32> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> func.func @transfer_read_2d_transpose_f32(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f32 @@ -101,7 +101,7 @@ func.func @transfer_read_2d_transpose_f32(%src : memref) { // ----- // CHECK-LABEL: @transfer_read_2d_transpose_f64 -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xf64> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xf64> func.func @transfer_read_2d_transpose_f64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 @@ -475,7 +475,7 @@ func.func @splat_vec2d_from_f16(%arg0: f16) { // CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index // CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref // CHECK: arm_sme.tile_store %[[TILE]], %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[16]x[16]xi8> -// CHECK: arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]], : memref, vector<[16]x[16]xi8> +// CHECK: arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] layout : memref, vector<[16]x[16]xi8> func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) { %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8> "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () @@ -487,7 +487,7 @@ func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) { // CHECK-LABEL: @transpose_i16 // CHECK: arith.constant 8 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[8]x[8]xi16> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xi16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xi16> func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) { %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16> "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () @@ -499,7 +499,7 @@ func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) { // CHECK-LABEL: @transpose_i32 // CHECK: arith.constant 4 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[4]x[4]xi32> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xi32> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xi32> func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) { %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> () @@ -511,7 +511,7 @@ func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) { // CHECK-LABEL: @transpose_i64 // CHECK: arith.constant 2 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[2]x[2]xi64> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xi64> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xi64> func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) { %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64> "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> () @@ -524,7 +524,7 @@ func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) { // CHECK: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[1]x[1]xi128> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[1]x[1]xi128> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[1]x[1]xi128> func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) { %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128> "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> () @@ -536,7 +536,7 @@ func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) { // CHECK-LABEL: @transpose_f16 // CHECK: arith.constant 8 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[8]x[8]xf16> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xf16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xf16> func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) { %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16> "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () @@ -548,7 +548,7 @@ func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) { // CHECK-LABEL: @transpose_bf16 // CHECK: arith.constant 8 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[8]x[8]xbf16> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[8]x[8]xbf16> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xbf16> func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) { %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16> "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () @@ -560,7 +560,7 @@ func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) { // CHECK-LABEL: @transpose_f32 // CHECK: arith.constant 4 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[4]x[4]xf32> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[4]x[4]xf32> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) { %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32> "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () @@ -572,7 +572,7 @@ func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) { // CHECK-LABEL: @transpose_f64 // CHECK: arith.constant 2 // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[2]x[2]xf64> -// CHECK: arm_sme.tile_load {{.*}}, : memref, vector<[2]x[2]xf64> +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xf64> func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) { %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir index 8c7d8c954d3847..179e9fa83662ec 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir @@ -63,7 +63,7 @@ func.func @entry() { } // Load tile from "mem1" vertically. - %0 = arm_sme.tile_load %mem1[%c0, %c0], : memref, vector<[4]x[4]xi32> + %0 = arm_sme.tile_load %mem1[%c0, %c0] layout : memref, vector<[4]x[4]xi32> // 1. ORIGINAL HORIZONTAL LAYOUT // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least