Skip to content

Commit

Permalink
[mlir][ArmSME] Update tile slice layout syntax (#69151)
Browse files Browse the repository at this point in the history
This patch prefixes tile slice layout with `layout` in the
assemblyFormat:

  - `<vertical>`   -> `layout<vertical>`
  - `<horizontal>` -> `layout<horizontal>`

The reason for this change is the current format doesn't play nicely
with additional optional operands, required to support padding and
masking (#69148), as it becomes ambiguous.

This affects the the following ops:

  - arm_sme.tile_load
  - arm_sme.tile_store
  - arm_sme.load_tile_slice
  - arm_sme.store_tile_slice
  • Loading branch information
c-rhodes authored Oct 16, 2023
1 parent c0a7dd4 commit d86047c
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 144 deletions.
39 changes: 17 additions & 22 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
"layout"> {
let assemblyFormat = "`<` $value `>`";
let defaultValue = "TileSliceLayout::Horizontal";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -248,19 +249,18 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {

Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile = arm_sme.tile_load %base[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```

Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);

Expand All @@ -274,7 +274,7 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];

let assemblyFormat =
"$base `[` $indices `]` (`,` $layout^)? attr-dict "
"$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
}

Expand All @@ -296,19 +296,17 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {

Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.tile_store %tile, %base[%c0, %c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
Expand All @@ -320,7 +318,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}];

let assemblyFormat =
"$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
"$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
}

Expand Down Expand Up @@ -348,19 +346,18 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);

Expand All @@ -374,7 +371,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
}];
}
Expand All @@ -401,19 +398,17 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {

Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
Expand All @@ -425,7 +420,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];

let assemblyFormat = [{
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($tile)
}];
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
/// arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
/// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
Expand All @@ -147,7 +147,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
/// <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace {
///
/// is converted to:
///
/// arm_sme.tile_load ... <vertical>
/// arm_sme.tile_load ... layout<vertical>
struct TransferReadPermutationToArmSMELowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
Expand Down Expand Up @@ -368,8 +368,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// NOTE: Tranposing via memory is obviously expensive, the current intention
/// is to avoid the transpose if possible, this is therefore intended as a
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
// -----

// CHECK-LABEL: @arm_sme_tile_load_ver
// CHECK: arm_sme.load_tile_slice {{.*}} <vertical>
// CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
%tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

Expand All @@ -50,10 +50,10 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
// -----

// CHECK-LABEL: @arm_sme_tile_store_ver
// CHECK: arm_sme.store_tile_slice {{.*}} <vertical>
// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

Expand Down
36 changes: 18 additions & 18 deletions mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}

Expand All @@ -126,7 +126,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}

Expand All @@ -136,7 +136,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

Expand All @@ -146,7 +146,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}

Expand All @@ -156,7 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}

Expand All @@ -166,7 +166,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}

Expand All @@ -176,7 +176,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}

Expand All @@ -186,7 +186,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}

Expand All @@ -196,7 +196,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

Expand Down Expand Up @@ -316,7 +316,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}

Expand All @@ -326,7 +326,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}

Expand All @@ -336,7 +336,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

Expand All @@ -346,7 +346,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}

Expand All @@ -356,7 +356,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}

Expand All @@ -366,7 +366,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}

Expand All @@ -376,7 +376,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}

Expand All @@ -386,7 +386,7 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}

Expand All @@ -396,7 +396,7 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

Expand Down
Loading

0 comments on commit d86047c

Please sign in to comment.