-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops #69186
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Cullen Rhodes (c-rhodes) ChangesThis is used in #69148 when lowering masked tile_store with non-zero pad, see 8589e50 Full diff: https://github.com/llvm/llvm-project/pull/69186.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..9b9dbff10ea2da6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
- within a ZA tile. Horizontal tile slices are currently assumed when
- lowering to intrinsics. The updated tile is returned as the result.
+ within a ZA tile. The updated tile is returned as the result.
- Example 1: Move a vector<[16]xi8> into tile at given index.
+ An optional tile slice layout attribute specifies whether the tile slice is
+ horizontal (default) or vertical.
+
+ Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```
- Example 2: Move a vector<[2]xf64> into tile at given index.
+ Example 2: Move a vector<[2]xf64> into tile vertically at given index.
```mlir
- %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
- SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+ SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);
let extraClassDeclaration = [{
@@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];
let assemblyFormat = [{
- $vector `,` $tile `,` $tile_slice_index
+ $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
@@ -480,21 +483,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
- horizontally or vertically contiguous elements within a ZA tile. Horizontal
- tile slices are currently assumed when lowering to intrinsics.
+ horizontally or vertically contiguous elements within a ZA tile.
+
+ An optional tile slice layout attribute specifies whether the tile slice is
+ horizontal (default) or vertical.
- Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+ Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```
- Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+ Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
```mlir
- %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];
- let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+ let arguments = (ins
+ SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout
+ );
let results = (outs SVEVector:$result);
let extraClassDeclaration = [{
@@ -502,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
let assemblyFormat = [{
- $tile `[` $tile_slice_index `]` attr-dict
+ $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..1231da356f8ed95 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
}
};
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
@@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
- // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileI32, tileSliceI32, allActiveMask,
- moveVectorToTileSliceOp.getVector());
+ // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+ switch (moveVectorToTileSliceOp.getLayout()) {
+ case arm_sme::TileSliceLayout::Horizontal:
+ rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+ loc, tileI32, tileSliceI32, allActiveMask,
+ moveVectorToTileSliceOp.getVector());
+ break;
+ case arm_sme::TileSliceLayout::Vertical:
+ rewriter.create<arm_sme::aarch64_sme_write_vert>(
+ loc, tileI32, tileSliceI32, allActiveMask,
+ moveVectorToTileSliceOp.getVector());
+ break;
+ }
// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
@@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), sliceIndex);
- // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
- moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
- tileIdI32, sliceIndexI32);
+ // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+ switch (moveTileSliceToVector.getLayout()) {
+ case arm_sme::TileSliceLayout::Horizontal:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+ moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+ tileIdI32, sliceIndexI32);
+ break;
+ case arm_sme::TileSliceLayout::Vertical:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+ moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+ tileIdI32, sliceIndexI32);
+ break;
+ }
return success();
}
@@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..9074f0a7ee655c1 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
return
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ return
+}
//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..e5ba81eff836027 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
return
}
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+ return
+}
//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
@@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Btw, would you mind updating the summary to list the Ops that are being updated? Otherwise that information is only available within the code :( I frequently scan git log
to see an overview and it's just easier if the information is there. Ta!
I've updated the summary and will land this as Cullen is currently away. |
This is used in #69148 when lowering masked tile_store with non-zero pad.
This updates:
arm_sme.move_vector_to_tile_slice
arm_sme.move_tile_slice_to_vector