Skip to content

Commit

Permalink
[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes (l…
Browse files Browse the repository at this point in the history
…lvm#100997)

This prevents a bunch of boilerplate conversions to/from IntegerAttrs
and int64_ts. Other than that this is a NFC.
  • Loading branch information
MacDue authored and banach-space committed Aug 7, 2024
1 parent a914a7e commit be84538
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 45 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :

def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
Expand Down
44 changes: 17 additions & 27 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
ArrayAttr masks = m.getMaskDimSizes();
ArrayRef<int64_t> masks = m.getMaskDimSizes();
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
if (i < dimSize)
if (maskIdx < dimSize)
allTrue = false;
if (i > 0)
if (maskIdx > 0)
allFalse = false;
}
if (allTrue)
Expand Down Expand Up @@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
Expand Down Expand Up @@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
// region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
sliceMaskDimSizes);
return success();
}
};
Expand Down Expand Up @@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}

if (constantMaskOp) {
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
auto numMaskOperands = maskDimSizes.size();

// Check every mask dim size to see whether it can be dropped
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
--i) {
if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
if (maskDimSizes[i] != 1)
return failure();
}

auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);

rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
newMaskOperandsAttr);
newMaskOperands);
return success();
}

Expand Down Expand Up @@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {

// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
applyPermutationToVector(newMaskDimSizes, permutation);
auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
transpOp, transpOp.getResultVectorType(),
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
return success();
}
};
Expand All @@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
auto dim = getMaskDimSizes()[0];
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
Expand All @@ -5846,17 +5840,15 @@ LogicalResult ConstantMaskOp::verify() {
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
if (resultScalableDims[index] && maskDimSize != 0 &&
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
Expand All @@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
return getMaskDimSizes()[0] == 1;
}
for (const auto [resultSize, intAttr] :
for (const auto [resultSize, maskDimSize] :
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < resultSize)
return false;
}
Expand Down Expand Up @@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
}

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, retTy,
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
return success();
}
};
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
bool value = dimSizes.front() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
value));
return success();
}

int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
int64_t trueDimSize = dimSizes.front();

if (rank == 1) {
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
Expand Down Expand Up @@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {

VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
loc, lowType, dimSizes.drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
return failure();

int64_t dropDim = oldType.getRank() - newType.getRank();
SmallVector<int64_t> dimSizes;
for (auto attr : mask.getMaskDimSizes())
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
Expand All @@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

auto newMask = rewriter.create<vector::ConstantMaskOp>(
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
mask.getLoc(), newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
Expand Down
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
ArrayRef<Attribute> maskDimSizes =
constantMaskOp.getMaskDimSizes().getValue();
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
auto origIndex =
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
IntegerAttr maskIndexAttr =
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndexAttr);
newMask = rewriter.create<vector::ConstantMaskOp>(
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t maskIndex = (origIndex + scale - 1) / scale;
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
}

while (!extractOps.empty()) {
Expand Down

0 comments on commit be84538

Please sign in to comment.