diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index e72064651c5cae..86d1172ac4957b 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -420,38 +420,59 @@ struct TileStoreOpConversion : public OpRewritePattern { auto tileType = tileStoreOp.getVectorType(); auto tileElementType = tileType.getElementType(); - // Create a loop that stores each ZA tile slice from memory. + auto predicateType = + VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); + + Value maskCols; + Value upperBound; + auto maskOp = tileStoreOp.getMask(); + if (maskOp) { + auto createMaskOp = maskOp.getDefiningOp(); + if (!createMaskOp) + return rewriter.notifyMatchFailure( + tileStoreOp, "unsupported mask op, only 'vector.create_mask' is " + "currently supported"); + + auto numRows = createMaskOp.getOperands()[0]; + auto numCols = createMaskOp.getOperands()[1]; + + upperBound = numRows; + maskCols = + rewriter.create(loc, predicateType, numCols); + } else { + // Store all tile slices if no mask. + auto minTileSlices = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + // This describes both the number of ZA tile slices and the number of + // elements in a vector of SVL bits for a given element type (SVL_B, + // SVL_H, + // ..., SVL_Q). + auto numTileSlices = + rewriter.create(loc, minTileSlices, vscale); + + upperBound = numTileSlices; + // Create an 'all true' predicate for the tile slice. + maskCols = rewriter.create( + loc, DenseElementsAttr::get(predicateType, true)); + } + + // Create a loop that stores each active ZA tile slice to memory. 
auto step = rewriter.create(loc, 1); - auto minTileSlices = rewriter.create( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); - auto vscale = - rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); - // This describes both the number of ZA tile slices and the number of - // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, - // ..., SVL_Q). - auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); - auto forOp = - rewriter.create(loc, lowerBound, numTileSlices, step); + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); rewriter.setInsertionPointToStart(forOp.getBody()); - // Create an 'all true' predicate for the tile slice. - auto predicateType = - VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto allTruePredicate = rewriter.create( - loc, DenseElementsAttr::get(predicateType, true)); - SmallVector memrefIndices; auto tileSliceIndex = forOp.getInductionVar(); getMemrefIndices(tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(), tileSliceIndex, - numTileSlices, memrefIndices, loc, rewriter); + upperBound, memrefIndices, loc, rewriter); rewriter.replaceOpWithNewOp( - tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, - allTruePredicate, tileStoreOp.getBase(), memrefIndices, - tileStoreOp.getLayout()); + tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols, + tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout()); return success(); } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 55ea56f42c96ed..58c6998870edd9 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -102,9 +102,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref : vector<[4]xi1> +// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index // 
CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK: %[[PTRUE_S:.*]] = arith.constant dense : vector<[4]xi1> // CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index // CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref) { @@ -123,6 +123,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1> +// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] { +// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index +// CHECK-NEXT: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xi32> +func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1> + arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref, vector<[4]x[4]xi32> + return +} + //===----------------------------------------------------------------------===// // vector.print //===----------------------------------------------------------------------===//