Skip to content

Commit

Permalink
[mlir][vector] Add mask elimination transform (#99314)
Browse files Browse the repository at this point in the history
This adds a new transform `eliminateVectorMasks()` which aims at
removing scalable `vector.create_masks` that will be all-true at
runtime. It attempts to do this by simply pattern-matching the mask
operands (similar to some canonicalizations), if that does not lead to
an answer (is all-true? yes/no), then value bounds analysis will be used
to find the lower bound of the unknown operands. If the lower bound is
>= to the corresponding mask vector type dim, then that dimension of the
mask is all true.

Note that the pattern matching prevents expensive value-bounds analysis
in cases where the mask won't be all true.

For example:
```mlir
%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>
```
From looking at `%c2` we can tell this is not going to be an all-true
mask, so we don't need to run the value-bounds analysis for
`%dynamicValue` (and can exit the transform early).

Note: Eliminating create_masks here means replacing them with all-true
constants (which will then lead to the masks folding away).
  • Loading branch information
MacDue authored Aug 9, 2024
1 parent badfb4b commit 9b06e25
Show file tree
Hide file tree
Showing 8 changed files with 401 additions and 61 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Predefined constant_mask kinds.
enum class ConstantMaskKind { AllFalse = 0, AllTrue };

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void buildTerminatedBody(OpBuilder &builder, Location loc);
Expand Down Expand Up @@ -168,6 +171,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
SmallVector<arith::ConstantIndexOp>
getAsConstantIndexOps(ArrayRef<Value> values);

/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,11 @@ def Vector_ConstantMaskOp :
```
}];

let builders = [
// Build with mixed static/dynamic operands.
OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
];

let extraClassDeclaration = [{
/// Return the result type of this op.
VectorType getVectorType() {
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir {
class MLIRContext;
Expand Down Expand Up @@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter);

// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
unsigned vscaleMin;
unsigned vscaleMax;
};

/// Attempts to eliminate redundant vector masks by replacing them with all-true
/// constants at the top of the function (which results in the masks folding
/// away). Note: Currently, this only runs for vector.create_mask ops and
/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
/// nothing. This is because these redundant masks are much more likely for
/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
/// code has static sizes, so simpler folds remove the masks.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange = {});

} // namespace vector
} // namespace mlir

Expand Down
125 changes: 64 additions & 61 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5776,6 +5776,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
VectorType type, ConstantMaskKind kind) {
assert(kind == ConstantMaskKind::AllTrue ||
kind == ConstantMaskKind::AllFalse);
build(builder, result, type,
kind == ConstantMaskKind::AllTrue
? type.getShape()
: SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
auto resultType = llvm::cast<VectorType>(getResult().getType());
// Check the corner case of 0-D vectors first.
Expand Down Expand Up @@ -5858,6 +5868,21 @@ LogicalResult CreateMaskOp::verify() {
return success();
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
if (value.getDefiningOp<vector::VectorScaleOp>())
return 1;
auto mul = value.getDefiningOp<arith::MulIOp>();
if (!mul)
return {};
auto lhs = mul.getLhs();
auto rhs = mul.getRhs();
if (lhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(rhs);
if (rhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(lhs);
return {};
}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
Expand Down Expand Up @@ -5889,73 +5914,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {

LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
VectorType retTy = createMaskOp.getResult().getType();
bool isScalable = retTy.isScalable();

// Check every mask operand
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto cst = getConstantIntValue(operand)) {
// Most basic case - this operand is a constant value. Note that for
// scalable dimensions, CreateMaskOp can be folded only if the
// corresponding operand is negative or zero.
if (retTy.getScalableDims()[opIdx] && *cst > 0)
return failure();

continue;
}

// Non-constant operands are not allowed for non-scalable vectors.
if (!isScalable)
return failure();

// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
// true" mask, so can also be treated as constant.
auto mul = operand.getDefiningOp<arith::MulIOp>();
if (!mul)
return failure();
auto mulLHS = mul.getRhs();
auto mulRHS = mul.getLhs();
bool isOneOpVscale =
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));

auto isConstantValMatchingDim =
[=, dim = retTy.getShape()[opIdx]](Value operand) {
auto constantVal = getConstantIntValue(operand);
return (constantVal.has_value() && constantVal.value() == dim);
};

bool isOneOpConstantMatchingDim =
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);

if (!isOneOpVscale || !isOneOpConstantMatchingDim)
return failure();
VectorType maskType = createMaskOp.getVectorType();
ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

// Special case: Rank zero shape.
constexpr std::array<int64_t, 1> rankZeroShape{1};
constexpr std::array<bool, 1> rankZeroScalableDims{false};
if (maskType.getRank() == 0) {
maskTypeDimSizes = rankZeroShape;
maskTypeDimScalableFlags = rankZeroScalableDims;
}

// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
std::optional dimSize = getConstantIntValue(operand);
if (!dimSize) {
// Although not a constant, it is safe to assume that `operand` is
// "vscale * maxDimSize".
maskDimSizes.push_back(maxDimSize);
continue;
}
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
// Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
// collect the `constantDims` (for the ConstantMaskOp).
SmallVector<int64_t, 4> constantDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Constant value.
// If the mask dim is non-scalable this can be any value.
// If the mask dim is scalable only zero (all-false) is supported.
if (maskTypeDimScalableFlags[i] && intSize >= 0)
return failure();
constantDims.push_back(*intSize);
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Constant vscale multiple (e.g. 4 x vscale).
// Must be all-true to fold to a ConstantMask.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
constantDims.push_back(*vscaleMultiplier);
} else {
return failure();
}
maskDimSizes.push_back(dimSizeVal);
}

// Clamp values to constant_mask bounds.
for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
value = std::clamp<int64_t>(value, 0, maskDimSize);

// If one of dim sizes is zero, set all dims to zero.
if (llvm::is_contained(constantDims, 0))
constantDims.assign(constantDims.size(), 0);

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
constantDims);
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
VectorMaskElimination.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
Expand Down
118 changes: 118 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
using namespace mlir::vector;
namespace {

/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
/// All-true masks can then be eliminated by simple folds.
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
vector::CreateMaskOp createMaskOp,
VscaleRange vscaleRange) {
auto maskType = createMaskOp.getVectorType();
auto maskTypeDimScalableFlags = maskType.getScalableDims();
auto maskTypeDimSizes = maskType.getShape();

struct UnknownMaskDim {
size_t position;
Value dimSize;
};

// Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
// that are not obviously constant). If any constant dimension is not all-true
// bail out early (as this transform only trying to resolve all-true masks).
// This avoids doing value-bounds anaylis in cases like:
// `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
// ...where it is known the mask is not all-true by looking at `%c2`.
SmallVector<UnknownMaskDim> unknownDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Mask not all-true for this dim.
if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
return failure();
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Mask not all-true for this dim.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
} else {
// Unknown (without further analysis).
unknownDims.push_back(UnknownMaskDim{i, dimSize});
}
}

for (auto [i, dimSize] : unknownDims) {
// Compute the lower bound for the unknown dimension (i.e. the smallest
// value it could be).
FailureOr<ConstantOrScalableBound> dimLowerBound =
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
presburger::BoundType::LB);
if (failed(dimLowerBound))
return failure();
auto dimLowerBoundSize = dimLowerBound->getSize();
if (failed(dimLowerBoundSize))
return failure();
if (dimLowerBoundSize->scalable) {
// 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
// this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
} else {
// 2. The lower bound, LB, is a constant.
// - If the mask dim size is scalable then this dim is not all-true.
if (maskTypeDimScalableFlags[i])
return failure();
// - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
}
}

// Replace createMaskOp with an all-true constant. This should result in the
// mask being removed in most cases (as xfer ops + vector.mask have folds to
// remove all-true masks).
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
return success();
}

} // namespace

namespace mlir::vector {

void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange) {
// TODO: Support fixed-size case. This is less likely to be useful as for
// fixed-size code dimensions are all static so masks tend to fold away.
if (!vscaleRange)
return;

OpBuilder::InsertionGuard g(rewriter);

// Build worklist so we can safely insert new ops in
// `resolveAllTrueCreateMaskOp()`.
SmallVector<vector::CreateMaskOp> worklist;
function.walk([&](vector::CreateMaskOp createMaskOp) {
worklist.push_back(createMaskOp);
});

rewriter.setInsertionPointToStart(&function.front());
for (auto mask : worklist)
(void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
}

} // namespace mlir::vector
Loading

0 comments on commit 9b06e25

Please sign in to comment.