-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Update tile slice layout syntax #69151
[mlir][ArmSME] Update tile slice layout syntax #69151
Conversation
This patch prefixes tile slice layout with `layout` in the assemblyFormat: - <vertical> -> layout<vertical> - <horizontal> -> layout<horizontal> The reason for this change is the current format doesn't play nicely with additional optional operands (required to support padding and masking in an upcoming patch), as it becomes ambiguous. This affects the the following ops: - arm_sme.tile_load - arm_sme.tile_store - arm_sme.load_tile_slice - arm_sme.store_tile_slice
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sme Author: Cullen Rhodes (c-rhodes) ChangesThis patch prefixes tile slice layout with
The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous. This affects the the following ops:
Patch is 55.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69151.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 049c9759d70bf43..dab54b63d8d22be 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
"layout"> {
let assemblyFormat = "`<` $value `>`";
+ let defaultValue = "TileSliceLayout::Horizontal";
}
//===----------------------------------------------------------------------===//
@@ -248,19 +249,18 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load %base[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```
Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
- DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
- "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -274,7 +274,7 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];
let assemblyFormat =
- "$base `[` $indices `]` (`,` $layout^)? attr-dict "
+ "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
}
@@ -296,19 +296,17 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.tile_store %tile, %base[%c0, %c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
- "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -320,7 +318,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}];
let assemblyFormat =
- "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
+ "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
}
@@ -348,19 +346,18 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```
Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
- DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
- "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -374,7 +371,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];
let assemblyFormat = [{
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
+ $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
}];
}
@@ -401,19 +398,17 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
- "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -425,7 +420,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
let assemblyFormat = [{
- $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
+ $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($tile)
}];
}
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 881cc8575fb4824..0ec51b7430c0213 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
+/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
@@ -147,7 +147,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
-/// <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index cbc5e468c729372..d06eb4f5b01c950 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -67,7 +67,7 @@ namespace {
///
/// is converted to:
///
-/// arm_sme.tile_load ... <vertical>
+/// arm_sme.tile_load ... layout<vertical>
struct TransferReadPermutationToArmSMELowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -368,8 +368,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
-/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
-/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
+/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// NOTE: Tranposing via memory is obviously expensive, the current intention
/// is to avoid the transpose if possible, this is therefore intended as a
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 09f148bcd42f593..4b3020970d6ccc1 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -21,10 +21,10 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
// -----
// CHECK-LABEL: @arm_sme_tile_load_ver
-// CHECK: arm_sme.load_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
@@ -50,10 +50,10 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
// -----
// CHECK-LABEL: @arm_sme_tile_store_ver
-// CHECK: arm_sme.store_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 4c16e5c488a74cd..07485b3ee8ddf86 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -116,7 +116,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
@@ -126,7 +126,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
@@ -136,7 +136,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
@@ -146,7 +146,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
@@ -156,7 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
@@ -166,7 +166,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
@@ -176,7 +176,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
@@ -186,7 +186,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
@@ -196,7 +196,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
@@ -316,7 +316,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
@@ -326,7 +326,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
@@ -336,7 +336,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
@@ -346,7 +346,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
@@ -356,7 +356,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
@@ -366,7 +366,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
@@ -376,7 +376,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.
Could you provide an example? Just to satisfy my curiosity. Ta!
@@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [ | |||
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout, | |||
"layout"> { | |||
let assemblyFormat = "`<` $value `>`"; | |||
let defaultValue = "TileSliceLayout::Horizontal"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, this means that:
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
can be replaced with:
ArmSME_TileSliceLayoutAttr:$layout
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, that's done in this patch, see below
See b99b39b which is adding optional padding and mask operands to
which is ambiguous. This adds the |
So the change would be:
I don't quite see the ambiguity, but do prefer the more verbose/explicit form anyway :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, ta
the ambiguity comes from having 2 optional groups with leading |
As when it sees a |
Ah, now it makes sense, thank you! [NIT] For the summary :) |
This patch prefixes tile slice layout with
layout
in the assemblyFormat:<vertical>
->layout<vertical>
<horizontal>
->layout<horizontal>
The reason for this change is the current format doesn't play nicely with additional optional operands, required to support padding and masking (#69148), as it becomes ambiguous.
This affects the the following ops: