Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Add mask operand to load_tile_slice #70655

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
AllTypesMatch<["tile", "result"]>
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"mask has i1 element type and is a slice of the result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">,
]> {
let summary = "Tile slice load and update operation";
let description = [{
Expand All @@ -406,23 +414,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

The SSA value `mask` specifies which elements read from the MemRef are
masked out. The `mask` type is an `i1` vector whose shape matches that of a
single slice of the result, i.e. the result type with its leading dimension dropped.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
```

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
Expand All @@ -438,8 +450,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
(`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
type($result)
}];
}

Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
///
/// AFTER:
/// ```mlir
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
/// %tile_id = arm_sme.get_tile_id : i32
/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
/// %vscale = vector.vscale
Expand All @@ -69,14 +70,20 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
/// %ptrue_s, %tile, %tile_slice_idx
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
// TODO: add masked patterns.
return rewriter.notifyMatchFailure(
tileLoadOp, "op has mask, needs masked pattern(s)");

OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
Expand Down Expand Up @@ -109,6 +116,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {

rewriter.setInsertionPointToStart(forOp.getBody());

// Create an 'all true' predicate for the tile slice.
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));

// Create 'arm_sme.load_tile_slice' to load tile slice from memory into
// tile.
SmallVector<Value> memrefIndices;
Expand All @@ -117,8 +130,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
rewriter.create<arm_sme::LoadTileSliceOp>(
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
tileSliceIndex, tileLoadOp.getLayout());
loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());

rewriter.setInsertionPointAfter(forOp);

Expand Down
37 changes: 16 additions & 21 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
loc, rewriter.getI32Type(), tileSlice);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto maskOp = loadTileSliceOp.getMask();

auto tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
Expand All @@ -195,48 +190,48 @@ struct LoadTileSliceToArmSMELowering
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
}
} else {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
tileI32, tileSliceI32);
break;
}
Expand Down
15 changes: 14 additions & 1 deletion mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// arm_sme.tile_load
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
Expand All @@ -10,8 +14,9 @@
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
Expand All @@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
return
}

//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
Expand Down Expand Up @@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
return
}

//===----------------------------------------------------------------------===//
// vector.print
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
Expand Down
Loading