Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops #69186

Merged
merged 1 commit into from
Oct 25, 2023

Conversation

c-rhodes
Copy link
Collaborator

@c-rhodes c-rhodes commented Oct 16, 2023

This is used in #69148 when lowering masked tile_store with non-zero pad.

This updates:

  • arm_sme.move_vector_to_tile_slice
  • arm_sme.move_tile_slice_to_vector

@llvmbot
Copy link
Member

llvmbot commented Oct 16, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)

Changes

This is used in #69148 when lowering masked tile_store with non-zero pad, see 8589e50


Full diff: https://github.com/llvm/llvm-project/pull/69186.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+22-14)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+30-13)
  • (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+32)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..9b9dbff10ea2da6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
     of a 2-D scalable vector tile at the given index. The type of the 1-D
     scalable vector to be moved must match the type of the tile slice. A tile
     slice is a 1-D vector of horizontally or vertically contiguous elements
-    within a ZA tile. Horizontal tile slices are currently assumed when
-    lowering to intrinsics. The updated tile is returned as the result.
+    within a ZA tile. The updated tile is returned as the result.
 
-    Example 1: Move a vector<[16]xi8> into tile at given index.
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
+
+    Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
     ```mlir
     %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
     ```
 
-    Example 2: Move a vector<[2]xf64> into tile at given index.
+    Example 2: Move a vector<[2]xf64> into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
     ```
   }];
   let arguments = (ins
-      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout);
   let results = (outs SMETile:$result);
 
   let extraClassDeclaration = [{
@@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $vector `,` $tile `,` $tile_slice_index
+    $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($vector) `into` type($result)
   }];
 }
@@ -480,21 +483,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   let description = [{
     The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
     scalable tile at the given index. A tile slice is a 1-D vector of
-    horizontally or vertically contiguous elements within a ZA tile. Horizontal
-    tile slices are currently assumed when lowering to intrinsics.
+    horizontally or vertically contiguous elements within a ZA tile.
+
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
 
-    Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+    Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
     ```mlir
     %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
     ```
 
-    Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+    Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
     ```mlir
-    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
     ```
   }];
 
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index,
+    ArmSME_TileSliceLayoutAttr:$layout
+  );
   let results = (outs SVEVector:$result);
 
   let extraClassDeclaration = [{
@@ -502,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 
   let assemblyFormat = [{
-      $tile `[` $tile_slice_index `]` attr-dict
+      $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
       `:` type($result) `from` type($tile)
   }];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..1231da356f8ed95 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
 struct MoveVectorToTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
   using ConvertOpToLLVMPattern<
@@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
 
-    // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
-    rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-        loc, tileI32, tileSliceI32, allActiveMask,
-        moveVectorToTileSliceOp.getVector());
+    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+    switch (moveVectorToTileSliceOp.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.create<arm_sme::aarch64_sme_write_vert>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
 struct MoveTileSliceToVectorArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
   using ConvertOpToLLVMPattern<
@@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
     auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI32Type(), sliceIndex);
 
-    // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
-    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
-        moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-        tileIdI32, sliceIndexI32);
+    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+    switch (moveTileSliceToVector.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    }
 
     return success();
   }
@@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
       arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..9074f0a7ee655c1 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..e5ba81eff836027 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
   return
 }
 
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Btw, would you mind updating the summary to list the Ops that are being updated? Otherwise that information is only available within the code :( I frequently scan git log to see an overview and it's just easier if the information is there. Ta!

@MacDue
Copy link
Member

MacDue commented Oct 25, 2023

I've updated the summary and will land this as Cullen is currently away.

@MacDue MacDue merged commit 2f055dd into llvm:main Oct 25, 2023
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants