Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add mask elimination transform #99314

Merged
merged 7 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Predefined constant_mask kinds.
enum class ConstantMaskKind { AllFalse = 0, AllTrue };

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void buildTerminatedBody(OpBuilder &builder, Location loc);
Expand Down Expand Up @@ -163,6 +166,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
SmallVector<arith::ConstantIndexOp>
getAsConstantIndexOps(ArrayRef<Value> values);

/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2362,6 +2362,11 @@ def Vector_ConstantMaskOp :
```
}];

let builders = [
// Build with mixed static/dynamic operands.
OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
];

let extraClassDeclaration = [{
/// Return the result type of this op.
VectorType getVectorType() {
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir {
class MLIRContext;
Expand Down Expand Up @@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter);

// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
unsigned vscaleMin;
unsigned vscaleMax;
};

/// Attempts to eliminate redundant vector masks by replacing them with all-true
/// constants at the top of the function (which results in the masks folding
/// away). Note: Currently, this only runs for vector.create_mask ops and
/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
/// nothing. This is because these redundant masks are much more likely for
/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
/// code has static sizes, so simpler folds remove the masks.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange = {});

} // namespace vector
} // namespace mlir

Expand Down
125 changes: 64 additions & 61 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5749,6 +5749,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
VectorType type, ConstantMaskKind kind) {
assert(kind == ConstantMaskKind::AllTrue ||
kind == ConstantMaskKind::AllFalse);
build(builder, result, type,
kind == ConstantMaskKind::AllTrue
? type.getShape()
: SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
auto resultType = llvm::cast<VectorType>(getResult().getType());
// Check the corner case of 0-D vectors first.
Expand Down Expand Up @@ -5831,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
return success();
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
if (value.getDefiningOp<vector::VectorScaleOp>())
return 1;
auto mul = value.getDefiningOp<arith::MulIOp>();
if (!mul)
return {};
auto lhs = mul.getLhs();
auto rhs = mul.getRhs();
if (lhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(rhs);
if (rhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(lhs);
return {};
}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
Expand Down Expand Up @@ -5862,73 +5887,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {

LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
VectorType retTy = createMaskOp.getResult().getType();
bool isScalable = retTy.isScalable();

// Check every mask operand
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto cst = getConstantIntValue(operand)) {
// Most basic case - this operand is a constant value. Note that for
// scalable dimensions, CreateMaskOp can be folded only if the
// corresponding operand is negative or zero.
if (retTy.getScalableDims()[opIdx] && *cst > 0)
return failure();

continue;
}

// Non-constant operands are not allowed for non-scalable vectors.
if (!isScalable)
return failure();

// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
// true" mask, so can also be treated as constant.
auto mul = operand.getDefiningOp<arith::MulIOp>();
if (!mul)
return failure();
auto mulLHS = mul.getRhs();
auto mulRHS = mul.getLhs();
bool isOneOpVscale =
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));

auto isConstantValMatchingDim =
[=, dim = retTy.getShape()[opIdx]](Value operand) {
auto constantVal = getConstantIntValue(operand);
return (constantVal.has_value() && constantVal.value() == dim);
};

bool isOneOpConstantMatchingDim =
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);

if (!isOneOpVscale || !isOneOpConstantMatchingDim)
return failure();
VectorType maskType = createMaskOp.getVectorType();
ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

// Special case: Rank zero shape.
constexpr std::array<int64_t, 1> rankZeroShape{1};
constexpr std::array<bool, 1> rankZeroScalableDims{false};
if (maskType.getRank() == 0) {
maskTypeDimSizes = rankZeroShape;
maskTypeDimScalableFlags = rankZeroScalableDims;
}

// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
std::optional dimSize = getConstantIntValue(operand);
if (!dimSize) {
// Although not a constant, it is safe to assume that `operand` is
// "vscale * maxDimSize".
maskDimSizes.push_back(maxDimSize);
continue;
}
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
// Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
// collect the `constantDims` (for the ConstantMaskOp).
SmallVector<int64_t, 4> constantDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Constant value.
// If the mask dim is non-scalable this can be any value.
// If the mask dim is scalable only zero (all-false) is supported.
if (maskTypeDimScalableFlags[i] && intSize >= 0)
return failure();
constantDims.push_back(*intSize);
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Constant vscale multiple (e.g. 4 x vscale).
// Must be all-true to fold to a ConstantMask.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
constantDims.push_back(*vscaleMultiplier);
} else {
return failure();
}
maskDimSizes.push_back(dimSizeVal);
}

// Clamp values to constant_mask bounds.
for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
value = std::clamp<int64_t>(value, 0, maskDimSize);

// If one of dim sizes is zero, set all dims to zero.
if (llvm::is_contained(constantDims, 0))
constantDims.assign(constantDims.size(), 0);

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
constantDims);
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
VectorMaskElimination.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
Expand Down
118 changes: 118 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
MacDue marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
using namespace mlir::vector;
namespace {

/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
/// All-true masks can then be eliminated by simple folds.
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
vector::CreateMaskOp createMaskOp,
VscaleRange vscaleRange) {
auto maskType = createMaskOp.getVectorType();
auto maskTypeDimScalableFlags = maskType.getScalableDims();
auto maskTypeDimSizes = maskType.getShape();

struct UnknownMaskDim {
size_t position;
Value dimSize;
};

// Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
// that are not obviously constant). If any constant dimension is not all-true
// bail out early (as this transform only trying to resolve all-true masks).
// This avoids doing value-bounds anaylis in cases like:
// `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
// ...where it is known the mask is not all-true by looking at `%c2`.
SmallVector<UnknownMaskDim> unknownDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Mask not all-true for this dim.
if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
return failure();
banach-space marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Mask not all-true for this dim.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
} else {
// Unknown (without further analysis).
unknownDims.push_back(UnknownMaskDim{i, dimSize});
}
}

for (auto [i, dimSize] : unknownDims) {
// Compute the lower bound for the unknown dimension (i.e. the smallest
// value it could be).
FailureOr<ConstantOrScalableBound> dimLowerBound =
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
presburger::BoundType::LB);
if (failed(dimLowerBound))
return failure();
c-rhodes marked this conversation as resolved.
Show resolved Hide resolved
auto dimLowerBoundSize = dimLowerBound->getSize();
if (failed(dimLowerBoundSize))
return failure();
if (dimLowerBoundSize->scalable) {
// 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
// this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
} else {
// 2. The lower bound, LB, is a constant.
// - If the mask dim size is scalable then this dim is not all-true.
if (maskTypeDimScalableFlags[i])
return failure();
// - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
}
}

// Replace createMaskOp with an all-true constant. This should result in the
// mask being removed in most cases (as xfer ops + vector.mask have folds to
// remove all-true masks).
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
return success();
}

} // namespace

namespace mlir::vector {

void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange) {
// TODO: Support fixed-size case. This is less likely to be useful as for
// fixed-size code dimensions are all static so masks tend to fold away.
if (!vscaleRange)
return;

OpBuilder::InsertionGuard g(rewriter);

// Build worklist so we can safely insert new ops in
// `resolveAllTrueCreateMaskOp()`.
SmallVector<vector::CreateMaskOp> worklist;
function.walk([&](vector::CreateMaskOp createMaskOp) {
worklist.push_back(createMaskOp);
});

rewriter.setInsertionPointToStart(&function.front());
for (auto mask : worklist)
(void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
}

} // namespace mlir::vector
Loading
Loading