Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Update tile slice layout syntax #69151

Merged

Conversation

c-rhodes
Copy link
Collaborator

This patch prefixes tile slice layout with layout in the assemblyFormat:

  • <vertical> -> layout<vertical>
  • <horizontal> -> layout<horizontal>

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

This affects the the following ops:

  • arm_sme.tile_load
  • arm_sme.tile_store
  • arm_sme.load_tile_slice
  • arm_sme.store_tile_slice

This patch prefixes tile slice layout with `layout` in the
assemblyFormat:

  - <vertical>   -> layout<vertical>
  - <horizontal> -> layout<horizontal>

The reason for this change is the current format doesn't play nicely
with additional optional operands (required to support padding and
masking in an upcoming patch), as it becomes ambiguous.

This affects the the following ops:

  - arm_sme.tile_load
  - arm_sme.tile_store
  - arm_sme.load_tile_slice
  - arm_sme.store_tile_slice
@llvmbot
Copy link
Member

llvmbot commented Oct 16, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)

Changes

This patch prefixes tile slice layout with layout in the assemblyFormat:

  • &lt;vertical&gt; -> layout&lt;vertical&gt;
  • &lt;horizontal&gt; -> layout&lt;horizontal&gt;

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

This affects the the following ops:

  • arm_sme.tile_load
  • arm_sme.tile_store
  • arm_sme.load_tile_slice
  • arm_sme.store_tile_slice

Patch is 55.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69151.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+17-22)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+2-2)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+3-3)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+4-4)
  • (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+18-18)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+76-76)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+18-18)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 049c9759d70bf43..dab54b63d8d22be 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
 def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
                                           "layout"> {
   let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "TileSliceLayout::Horizontal";
 }
 
 //===----------------------------------------------------------------------===//
@@ -248,19 +249,18 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
 
     Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load %base[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
 
@@ -274,7 +274,7 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-    "$base `[` $indices `]` (`,` $layout^)? attr-dict "
+    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
       "`:` type($base) `,` type($result)";
 }
 
@@ -296,19 +296,17 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
 
     Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, %base[%c0, %c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
     Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -320,7 +318,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   }];
 
   let assemblyFormat =
-    "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
+    "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
       "`:` type($base) `,` type($valueToStore)";
 }
 
@@ -348,19 +346,18 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 
     Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from">:$base,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
 
@@ -374,7 +371,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
+    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -401,19 +398,17 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
 
     Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
     Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -425,7 +420,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
+    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 881cc8575fb4824..0ec51b7430c0213 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
+///  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -147,7 +147,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
-///      <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+///      layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index cbc5e468c729372..d06eb4f5b01c950 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -67,7 +67,7 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_load ... <vertical>
+///   arm_sme.tile_load ... layout<vertical>
 struct TransferReadPermutationToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -368,8 +368,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
 ///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
 ///   %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
 ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
-///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
-///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
+///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///
 /// NOTE: Tranposing via memory is obviously expensive, the current intention
 /// is to avoid the transpose if possible, this is therefore intended as a
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 09f148bcd42f593..4b3020970d6ccc1 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -21,10 +21,10 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
 // -----
 
 // CHECK-LABEL: @arm_sme_tile_load_ver
-// CHECK: arm_sme.load_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical>
 func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -50,10 +50,10 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
 // -----
 
 // CHECK-LABEL: @arm_sme_tile_store_ver
-// CHECK: arm_sme.store_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
 func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 4c16e5c488a74cd..07485b3ee8ddf86 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -116,7 +116,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -126,7 +126,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
@@ -136,7 +136,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -146,7 +146,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
@@ -156,7 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
@@ -166,7 +166,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
@@ -176,7 +176,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -186,7 +186,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
@@ -196,7 +196,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -316,7 +316,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -326,7 +326,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
@@ -336,7 +336,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -346,7 +346,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
@@ -356,7 +356,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
@@ -366,7 +366,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
@@ -376,7 +376,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme...
[truncated]

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

Could you provide an example? Just to satisfy my curiosity. Ta!

@@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
"layout"> {
let assemblyFormat = "`<` $value `>`";
let defaultValue = "TileSliceLayout::Horizontal";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this means that:

DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout

can be replaced with:

    ArmSME_TileSliceLayoutAttr:$layout

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's done in this patch, see below

@c-rhodes
Copy link
Collaborator Author

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

Could you provide an example? Just to satisfy my curiosity. Ta!

See b99b39b which is adding optional padding and mask operands to arm_sme.tile_load. The assembly format is updated to add these, with the current layout format it would be:

  let assemblyFormat =
    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`,` $layout^)?"
      "attr-dict `:` type($base) `,` type($result)";

which is ambiguous. This adds the layout prefix so it isn't so, which is actually consistent with other optional attributes (e.g. fastmath)

@banach-space
Copy link
Contributor

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

Could you provide an example? Just to satisfy my curiosity. Ta!

See b99b39b which is adding optional padding and mask operands to arm_sme.tile_load. The assembly format is updated to add these, with the current layout format it would be:

  let assemblyFormat =
    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`,` $layout^)?"
      "attr-dict `:` type($base) `,` type($result)";

which is ambiguous. This adds the layout prefix so it isn't so, which is actually consistent with other optional attributes (e.g. fastmath)

So the change would be:

  // BEFORE
  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
  // AFTER
  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>

I don't quite see the ambiguity, but do prefer the more verbose/explicit form anyway :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, ta

@c-rhodes
Copy link
Collaborator Author

The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.

Could you provide an example? Just to satisfy my curiosity. Ta!

See b99b39b which is adding optional padding and mask operands to arm_sme.tile_load. The assembly format is updated to add these, with the current layout format it would be:

  let assemblyFormat =
    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`,` $layout^)?"
      "attr-dict `:` type($base) `,` type($result)";

which is ambiguous. This adds the layout prefix so it isn't so, which is actually consistent with other optional attributes (e.g. fastmath)

So the change would be:

  // BEFORE
  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
  // AFTER
  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>

I don't quite see the ambiguity, but do prefer the more verbose/explicit form anyway :)

the ambiguity comes from having 2 optional groups with leading ,, the generator can't handle this.

@MacDue
Copy link
Member

MacDue commented Oct 16, 2023

the ambiguity comes from having 2 optional groups with leading ,, the generator can't handle this.

As when it sees a , it does not know if what follows should be the padding or the layout. It's not clever enough to look ahead and see the next token must be a layout (or must be padding), so it requires you to use the explicit syntax (#arm_sme.layout<horizontal>).

@banach-space
Copy link
Contributor

banach-space commented Oct 16, 2023

the ambiguity comes from having 2 optional groups with leading ,, the generator can't handle this.

As when it sees a , it does not know if what follows should be the padding or the layout. It's not clever enough to look ahead and see the next token must be a layout (or must be padding), so it requires you to use the explicit syntax (#arm_sme.layout<horizontal>).

Ah, now it makes sense, thank you!

[NIT] , <vertical> -> layout<vertical> --> <vertical> -> layout<vertical>

For the summary :)

@c-rhodes c-rhodes merged commit d86047c into llvm:main Oct 16, 2023
6 checks passed
@c-rhodes c-rhodes deleted the mlir-arm-sme-update-tile-slice-layout-syntax branch October 16, 2023 09:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants