diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f4..ebe6cd4a62b4c5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -56,6 +56,9 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Predefined constant_mask kinds.
+enum class ConstantMaskKind { AllFalse = 0, AllTrue };
+
 /// Default callback to build a region with a 'vector.yield' terminator with no
 /// arguments.
 void buildTerminatedBody(OpBuilder &builder, Location loc);
@@ -163,6 +166,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
 SmallVector<arith::ConstantIndexOp>
 getAsConstantIndexOps(ArrayRef<Value> values);
 
+/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
+/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
+/// `std::nullopt`.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value);
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd19d356a6739d..80ec996e4b8a69 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2362,6 +2362,11 @@ def Vector_ConstantMaskOp :
     ```
   }];
 
+  let builders = [
+    // Build with a predefined mask kind (all-true or all-false).
+    OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
+  ];
+
   let extraClassDeclaration = [{
     /// Return the result type of this op.
     VectorType getVectorType() {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 1f7d6411cd5a46..e815e026305fab 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class MLIRContext;
@@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                  MaskingOpInterface maskingOp,
                                  RewriterBase &rewriter);
 
+// Structure to hold the range of `vector.vscale`.
+struct VscaleRange {
+  unsigned vscaleMin;
+  unsigned vscaleMax;
+};
+
+/// Attempts to eliminate redundant vector masks by replacing them with
+/// all-true constants at the top of the function (which results in the masks
+/// folding away). Note: Currently, this only runs for vector.create_mask ops
+/// and requires `vscaleRange`. If `vscaleRange` is not provided, this
+/// transform does nothing. This is because such redundant masks are much more
+/// common in scalable code, which requires dynamic memref/tensor sizes,
+/// whereas fixed-size code has static sizes, so simpler folds remove the
+/// masks.
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange = {});
+
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab39..9879c90ca490f3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5749,6 +5749,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
 
+void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType type, ConstantMaskKind kind) {
+  assert(kind == ConstantMaskKind::AllTrue ||
+         kind == ConstantMaskKind::AllFalse);
+  build(builder, result, type,
+        kind == ConstantMaskKind::AllTrue
+            ? type.getShape()
+            : SmallVector<int64_t>(type.getRank(), 0));
+}
+
 LogicalResult ConstantMaskOp::verify() {
   auto resultType = llvm::cast<VectorType>(getResult().getType());
   // Check the corner case of 0-D vectors first.
@@ -5831,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
   return success();
 }
 
+std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
+  if (value.getDefiningOp<vector::VectorScaleOp>())
+    return 1;
+  auto mul = value.getDefiningOp<arith::MulIOp>();
+  if (!mul)
+    return {};
+  auto lhs = mul.getLhs();
+  auto rhs = mul.getRhs();
+  if (lhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(rhs);
+  if (rhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(lhs);
+  return {};
+}
+
 namespace {
 
 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5862,73 +5887,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType retTy = createMaskOp.getResult().getType();
-    bool isScalable = retTy.isScalable();
-
-    // Check every mask operand
-    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
-      if (auto cst = getConstantIntValue(operand)) {
-        // Most basic case - this operand is a constant value. Note that for
-        // scalable dimensions, CreateMaskOp can be folded only if the
-        // corresponding operand is negative or zero.
-        if (retTy.getScalableDims()[opIdx] && *cst > 0)
-          return failure();
-
-        continue;
-      }
-
-      // Non-constant operands are not allowed for non-scalable vectors.
-      if (!isScalable)
-        return failure();
-
-      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
-      // true" mask, so can also be treated as constant.
-      auto mul = operand.getDefiningOp<arith::MulIOp>();
-      if (!mul)
-        return failure();
-      auto mulLHS = mul.getRhs();
-      auto mulRHS = mul.getLhs();
-      bool isOneOpVscale =
-          (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
-           isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
-
-      auto isConstantValMatchingDim =
-          [=, dim = retTy.getShape()[opIdx]](Value operand) {
-            auto constantVal = getConstantIntValue(operand);
-            return (constantVal.has_value() && constantVal.value() == dim);
-          };
-
-      bool isOneOpConstantMatchingDim =
-          isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
-
-      if (!isOneOpVscale || !isOneOpConstantMatchingDim)
-        return failure();
+    VectorType maskType = createMaskOp.getVectorType();
+    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
+    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
+
+    // Special case: Rank zero shape.
+    constexpr std::array<int64_t, 1> rankZeroShape{1};
+    constexpr std::array<bool, 1> rankZeroScalableDims{false};
+    if (maskType.getRank() == 0) {
+      maskTypeDimSizes = rankZeroShape;
+      maskTypeDimScalableFlags = rankZeroScalableDims;
     }
 
-    // Gather constant mask dimension sizes.
-    SmallVector<int64_t, 4> maskDimSizes;
-    maskDimSizes.reserve(createMaskOp->getNumOperands());
-    for (auto [operand, maxDimSize] : llvm::zip_equal(
-             createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      std::optional<int64_t> dimSize = getConstantIntValue(operand);
-      if (!dimSize) {
-        // Although not a constant, it is safe to assume that `operand` is
-        // "vscale * maxDimSize".
-        maskDimSizes.push_back(maxDimSize);
-        continue;
-      }
-      int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
-      // If one of dim sizes is zero, set all dims to zero.
-      if (dimSize <= 0) {
-        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
-        break;
+    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
+    // collect the `constantDims` (for the ConstantMaskOp).
+    SmallVector<int64_t> constantDims;
+    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+      if (auto intSize = getConstantIntValue(dimSize)) {
+        // Constant value.
+        // If the mask dim is non-scalable this can be any value.
+        // If the mask dim is scalable only zero (all-false) is supported.
+        if (maskTypeDimScalableFlags[i] && intSize >= 0)
+          return failure();
+        constantDims.push_back(*intSize);
+      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+        // Constant vscale multiple (e.g. 4 x vscale).
+        // Must be all-true to fold to a ConstantMaskOp.
+        if (vscaleMultiplier < maskTypeDimSizes[i])
+          return failure();
+        constantDims.push_back(*vscaleMultiplier);
+      } else {
+        return failure();
       }
-      maskDimSizes.push_back(dimSizeVal);
     }
 
+    // Clamp values to constant_mask bounds.
+    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
+      value = std::clamp<int64_t>(value, 0, maskDimSize);
+
+    // If one of the dim sizes is zero, set all dims to zero.
+    if (llvm::is_contained(constantDims, 0))
+      constantDims.assign(constantDims.size(), 0);
+
     // Replace 'createMaskOp' with ConstantMaskOp.
-    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
-                                                maskDimSizes);
+    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
+                                                constantDims);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 723b2f62d65d4f..2639a67e1c8b31 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
+  VectorMaskElimination.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
new file mode 100644
index 00000000000000..363108238e5960
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -0,0 +1,118 @@
+//===- VectorMaskElimination.cpp - Eliminate Vector Masks ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
+/// All-true masks can then be eliminated by simple folds.
+LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
+                                         vector::CreateMaskOp createMaskOp,
+                                         VscaleRange vscaleRange) {
+  auto maskType = createMaskOp.getVectorType();
+  auto maskTypeDimScalableFlags = maskType.getScalableDims();
+  auto maskTypeDimSizes = maskType.getShape();
+
+  struct UnknownMaskDim {
+    size_t position;
+    Value dimSize;
+  };
+
+  // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
+  // that are not obviously constant). If any constant dimension is not
+  // all-true, bail out early (as this transform is only trying to resolve
+  // all-true masks). This avoids doing value-bounds analysis in cases like:
+  // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
+  // ...where it is known the mask is not all-true by looking at `%c2`.
+  SmallVector<UnknownMaskDim> unknownDims;
+  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      // Mask not all-true for this dim.
+      if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
+        return failure();
+    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+      // Mask not all-true for this dim.
+      if (vscaleMultiplier < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // Unknown (without further analysis).
+      unknownDims.push_back(UnknownMaskDim{i, dimSize});
+    }
+  }
+
+  for (auto [i, dimSize] : unknownDims) {
+    // Compute the lower bound for the unknown dimension (i.e. the smallest
+    // value it could be).
+    FailureOr<ConstantOrScalableBound> dimLowerBound =
+        vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+            dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
+            presburger::BoundType::LB);
+    if (failed(dimLowerBound))
+      return failure();
+    auto dimLowerBoundSize = dimLowerBound->getSize();
+    if (failed(dimLowerBoundSize))
+      return failure();
+    if (dimLowerBoundSize->scalable) {
+      // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
+      // this dim is not all-true.
+      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // 2. The lower bound, LB, is a constant.
+      // - If the mask dim size is scalable then this dim is not all-true.
+      if (maskTypeDimScalableFlags[i])
+        return failure();
+      // - If LB < the _fixed-size_ mask dim size then this dim is not
+      //   all-true.
+      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    }
+  }
+
+  // Replace createMaskOp with an all-true constant. This should result in the
+  // mask being removed in most cases (as xfer ops + vector.mask have folds to
+  // remove all-true masks).
+  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
+      createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
+  rewriter.replaceAllUsesWith(createMaskOp, allTrue);
+  return success();
+}
+
+} // namespace
+
+namespace mlir::vector {
+
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange) {
+  // TODO: Support the fixed-size case. This is less likely to be useful, as
+  // for fixed-size code dimensions are all static, so masks tend to fold away.
+  if (!vscaleRange)
+    return;
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // Build a worklist so we can safely insert new ops in
+  // `resolveAllTrueCreateMaskOp()`.
+  SmallVector<vector::CreateMaskOp> worklist;
+  function.walk([&](vector::CreateMaskOp createMaskOp) {
+    worklist.push_back(createMaskOp);
+  });
+
+  rewriter.setInsertionPointToStart(&function.front());
+  for (auto mask : worklist)
+    (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
+}
+
+} // namespace mlir::vector
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
new file mode 100644
index 00000000000000..0b89b0604faab1
--- /dev/null
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks | FileCheck %s
+
+// This tests a general pattern the vectorizer tends to emit.
+
+// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
+// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
+// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
+// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
+func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
+  %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
+    // 1. Extract a slice.
+    %extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+
+    // 2. Create a mask for the slice.
+    %dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim_1 : vector<[4]xi1>
+
+    // 3. Read the slice and do some computation.
+    %vec = vector.transfer_read %extracted_slice_1[%c0], %c0_f32, %mask {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32>
+    %new_vec = "test.some_computation"(%vec) : (vector<[4]xf32>) -> (vector<[4]xf32>)
+
+    // 4. Write the new value.
+    %write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>
+
+    // 5. Insert and yield the new tensor value.
+    %result = tensor.insert_slice %write into %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    scf.yield %result : tensor<1x?xf32>
+  }
+  "test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_extract_slice_size_shrink
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<[4]xi1>) -> ()
+func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice = tensor.extract_slice %tensor[0] [%c4_vscale] [1] : tensor<1000xf32> to tensor<?xf32>
+  %slice = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice) -> tensor<?xf32> {
+    // This mask cannot be eliminated even though, looking at the operations
+    // above this comment, it appears `tensor.dim` will always be %c4_vscale
+    // (i.e. the mask would be all-true).
+    %dim = tensor.dim %arg, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim : vector<[4]xi1>
+    "test.some_use"(%mask) : (vector<[4]xi1>) -> ()
+    // !!! Here the size of the mask could shrink in the next iteration.
+    %next_num_elts = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
+    %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_elts] [1] : tensor<1000xf32> to tensor<?xf32>
+    scf.yield %new_extracted_slice : tensor<?xf32>
+  }
+  "test.some_use"(%slice) : (tensor<?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @trivially_all_true_case
+// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [2, 4] : vector<2x[4]xi1>
+// CHECK: "test.some_use"(%[[ALL_TRUE_MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @trivially_all_true_case(%tensor: tensor<2x?xf32>)
+{
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // This is found to be all-true _without_ value bounds analysis.
+  %mask = vector.create_mask %c2, %c4_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_dim_not_all_true
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_dim_not_all_true()
+{
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // Since %c1 is a constant, this will be found not to be all-true via simple
+  // pattern matching.
+  %mask = vector.create_mask %c1, %c4_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_vscale_multiple_not_all_true
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_vscale_multiple_not_all_true() {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  // Since %c3_vscale is a constant vscale multiple, this will be found not to
+  // be all-true via simple pattern matching.
+  %mask = vector.create_mask %c2, %c3_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_fixed_dim_not_all_true
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // This is _very_ simple, but since tensor.dim is not a constant, value
+  // bounds analysis will be used to resolve it.
+  %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
+  %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_scalable_dim_not_all_true
+// CHECK-NOT: vector.constant_mask
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_scalable_dim_not_all_true(%tensor: tensor<2x100xf32>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  %slice = tensor.extract_slice %tensor[0, 0] [2, %c3_vscale] [1, 1] : tensor<2x100xf32> to tensor<2x?xf32>
+  // Another simple example, but value bounds analysis will be used to resolve
+  // the tensor.dim.
+  %dim = tensor.dim %slice, %c1 : tensor<2x?xf32>
+  %mask = vector.create_mask %c3, %dim : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 592e24af94d677..29c763b622e877 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -873,6 +873,33 @@ struct TestVectorLinearize final
     return signalPassFailure();
   }
 };
+
+struct TestEliminateVectorMasks
+    : public PassWrapper<TestEliminateVectorMasks,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
+
+  TestEliminateVectorMasks() = default;
+  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
+      : PassWrapper(pass) {}
+
+  Option<unsigned> vscaleMin{
+      *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
+      llvm::cl::init(1)};
+  Option<unsigned> vscaleMax{
+      *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
+      llvm::cl::init(16)};
+
+  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
+  StringRef getDescription() const final {
+    return "Test eliminating vector masks";
+  }
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+    eliminateVectorMasks(rewriter, getOperation(),
+                         VscaleRange{vscaleMin, vscaleMax});
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -919,6 +946,8 @@ void registerTestVectorLowerings() {
   PassRegistration();
   PassRegistration();
+
+  PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
 } // namespace mlir
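
Usage note (illustrative only, not part of the patch): with the test pass registered above, `-test-eliminate-vector-masks` rewrites `vector.create_mask` ops that are provably all-true into `vector.constant_mask`, which existing folds can then remove. A minimal before/after sketch, distilled from the @trivially_all_true_case test and assuming the test pass's default vscale range of [1, 16]:

  // Before: the mask covers 2 x (4 x vscale) elements, i.e. all of vector<2x[4]xi1>.
  %vscale = vector.vscale
  %c4_vscale = arith.muli %vscale, %c4 : index
  %mask = vector.create_mask %c2, %c4_vscale : vector<2x[4]xi1>

  // After -test-eliminate-vector-masks: an all-true constant mask, which later folds away.
  %mask = vector.constant_mask [2, 4] : vector<2x[4]xi1>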