Skip to content

Commit

Permalink
[mlir][ArmSME] Add mask operand to load_tile_slice (#70655)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-rhodes authored Oct 31, 2023
1 parent 4b29e8c commit 8ea260a
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 127 deletions.
27 changes: 20 additions & 7 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
AllTypesMatch<["tile", "result"]>
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"mask has i1 element type and is a slice of the result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">,
]> {
let summary = "Tile slice load and update operation";
let description = [{
Expand All @@ -406,23 +414,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An SSA value `mask` masks out elements read from the MemRef. The `mask`
type is an `i1` vector whose shape matches a single slice of the result
tile, i.e. the result type with its leading dimension dropped.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
```

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
Expand All @@ -438,8 +450,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
(`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
type($result)
}];
}

Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
///
/// AFTER:
/// ```mlir
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
/// %tile_id = arm_sme.get_tile_id : i32
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
/// %vscale = vector.vscale
Expand All @@ -69,14 +70,20 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
/// %ptrue_s, %tile, %tile_slice_idx
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
// TODO: add masked patterns.
return rewriter.notifyMatchFailure(
tileLoadOp, "op has mask, needs masked pattern(s)");

OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
Expand Down Expand Up @@ -109,6 +116,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {

rewriter.setInsertionPointToStart(forOp.getBody());

// Create an 'all true' predicate for the tile slice.
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));

// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
// tile.
SmallVector<Value> memrefIndices;
Expand All @@ -117,8 +130,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
rewriter.create<arm_sme::LoadTileSliceOp>(
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
tileSliceIndex, tileLoadOp.getLayout());
loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());

rewriter.setInsertionPointAfter(forOp);

Expand Down
37 changes: 16 additions & 21 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
loc, rewriter.getI32Type(), tileSlice);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto maskOp = loadTileSliceOp.getMask();

auto tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
Expand All @@ -195,48 +190,48 @@ struct LoadTileSliceToArmSMELowering
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
}
} else {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
}
Expand Down
15 changes: 14 additions & 1 deletion mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// arm_sme.tile_load
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
Expand All @@ -10,8 +14,9 @@
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
Expand All @@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}

//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
Expand Down Expand Up @@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
return
}

//===----------------------------------------------------------------------===//
// vector.print
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
Expand Down
Loading

0 comments on commit 8ea260a

Please sign in to comment.