Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes #100997

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :

def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
Expand Down
44 changes: 17 additions & 27 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
ArrayAttr masks = m.getMaskDimSizes();
ArrayRef<int64_t> masks = m.getMaskDimSizes();
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
if (i < dimSize)
if (maskIdx < dimSize)
allTrue = false;
if (i > 0)
if (maskIdx > 0)
allFalse = false;
}
if (allTrue)
Expand Down Expand Up @@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
Expand Down Expand Up @@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
// region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
sliceMaskDimSizes);
return success();
}
};
Expand Down Expand Up @@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}

if (constantMaskOp) {
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
auto numMaskOperands = maskDimSizes.size();

// Check every mask dim size to see whether it can be dropped
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
--i) {
if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
if (maskDimSizes[i] != 1)
return failure();
}

auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);

rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
newMaskOperandsAttr);
newMaskOperands);
return success();
}

Expand Down Expand Up @@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {

// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
applyPermutationToVector(newMaskDimSizes, permutation);
auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
transpOp, transpOp.getResultVectorType(),
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
return success();
}
};
Expand All @@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
auto dim = getMaskDimSizes()[0];
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
Expand All @@ -5846,17 +5840,15 @@ LogicalResult ConstantMaskOp::verify() {
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
if (resultScalableDims[index] && maskDimSize != 0 &&
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
Expand All @@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
return getMaskDimSizes()[0] == 1;
}
for (const auto [resultSize, intAttr] :
for (const auto [resultSize, maskDimSize] :
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < resultSize)
return false;
}
Expand Down Expand Up @@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
}

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, retTy,
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
return success();
}
};
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
bool value = dimSizes.front() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
value));
return success();
}

int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
int64_t trueDimSize = dimSizes.front();

if (rank == 1) {
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
Expand Down Expand Up @@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {

VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
loc, lowType, dimSizes.drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
return failure();

int64_t dropDim = oldType.getRank() - newType.getRank();
SmallVector<int64_t> dimSizes;
for (auto attr : mask.getMaskDimSizes())
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
Expand All @@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

auto newMask = rewriter.create<vector::ConstantMaskOp>(
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
mask.getLoc(), newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
Expand Down
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
ArrayRef<Attribute> maskDimSizes =
constantMaskOp.getMaskDimSizes().getValue();
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
auto origIndex =
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
IntegerAttr maskIndexAttr =
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndexAttr);
newMask = rewriter.create<vector::ConstantMaskOp>(
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t maskIndex = (origIndex + scale - 1) / scale;
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
}

while (!extractOps.empty()) {
Expand Down
Loading