From be84538b346d834a57590c8191428b2e23c8b395 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Mon, 29 Jul 2024 18:08:37 +0100 Subject: [PATCH] [mlir][vector] Use `DenseI64ArrayAttr` for constant_mask dim sizes (#100997) This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 44 +++++++------------ .../Vector/Transforms/LowerVectorMask.cpp | 6 +-- .../Transforms/VectorDropLeadUnitDim.cpp | 6 +-- .../Transforms/VectorEmulateNarrowType.cpp | 17 +++---- 5 files changed, 30 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 39ad03c8011409a..3cdbd218745675f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2443,7 +2443,7 @@ def Vector_TypeCastOp : def Vector_ConstantMaskOp : Vector_Op<"constant_mask", [Pure]>, - Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, + Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>, Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a constant vector mask"; let description = [{ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d297c40760cd83e..669ae586e578612 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) { // Inspect constant mask index. If the index exceeds the // dimension size, all bits are set. If the index is zero // or less, no bits are set. - ArrayAttr masks = m.getMaskDimSizes(); + ArrayRef masks = m.getMaskDimSizes(); auto shape = m.getType().getShape(); bool allTrue = true; bool allFalse = true; for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { - int64_t i = llvm::cast(maskIdx).getInt(); - if (i < dimSize) + if (maskIdx < dimSize) allTrue = false; - if (i > 0) + if (maskIdx > 0) allFalse = false; } if (allTrue) @@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final if (extractStridedSliceOp.hasNonUnitStrides()) return failure(); // Gather constant mask dimension sizes. - SmallVector maskDimSizes; - populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes); + ArrayRef maskDimSizes = constantMaskOp.getMaskDimSizes(); // Gather strided slice offsets and sizes. SmallVector sliceOffsets; populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), @@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final // region. rewriter.replaceOpWithNewOp( extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), - vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); + sliceMaskDimSizes); return success(); } }; @@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final } if (constantMaskOp) { - auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue(); + auto maskDimSizes = constantMaskOp.getMaskDimSizes(); auto numMaskOperands = maskDimSizes.size(); // Check every mask dim size to see whether it can be dropped for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop; --i) { - if (cast(maskDimSizes[i]).getValue() != 1) + if (maskDimSizes[i] != 1) return failure(); } auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop); - ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands); - rewriter.replaceOpWithNewOp(shapeOp, shapeOpResTy, - newMaskOperandsAttr); + newMaskOperands); return success(); } @@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern { // ConstantMaskOp case. auto maskDimSizes = constantMaskOp.getMaskDimSizes(); - SmallVector newMaskDimSizes(maskDimSizes.getValue()); - applyPermutationToVector(newMaskDimSizes, permutation); + auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation); rewriter.replaceOpWithNewOp( - transpOp, transpOp.getResultVectorType(), - ArrayAttr::get(transpOp.getContext(), newMaskDimSizes)); + transpOp, transpOp.getResultVectorType(), newMaskDimSizes); return success(); } }; @@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() { if (resultType.getRank() == 0) { if (getMaskDimSizes().size() != 1) return emitError("array attr must have length 1 for 0-D vectors"); - auto dim = llvm::cast(getMaskDimSizes()[0]).getInt(); + auto dim = getMaskDimSizes()[0]; if (dim != 0 && dim != 1) return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); @@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() { // result dimension size. auto resultShape = resultType.getShape(); auto resultScalableDims = resultType.getScalableDims(); - SmallVector maskDimSizes; - for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) { - int64_t maskDimSize = llvm::cast(intAttr).getInt(); + ArrayRef maskDimSizes = getMaskDimSizes(); + for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) { if (maskDimSize < 0 || maskDimSize > resultShape[index]) return emitOpError( "array attr of size out of bounds of vector result dimension size"); @@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() { maskDimSize != resultShape[index]) return emitOpError( "only supports 'none set' or 'all set' scalable dimensions"); - maskDimSizes.push_back(maskDimSize); } // Verify that if one mask dim size is zero, they all should be zero (because // the mask region is a conjunction of each mask dimension interval). @@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() { // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask"); - return llvm::cast(getMaskDimSizes()[0]).getInt() == 1; + return getMaskDimSizes()[0] == 1; } - for (const auto [resultSize, intAttr] : + for (const auto [resultSize, maskDimSize] : llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) { - int64_t maskDimSize = llvm::cast(intAttr).getInt(); if (maskDimSize < resultSize) return false; } @@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern { } // Replace 'createMaskOp' with ConstantMaskOp. - rewriter.replaceOpWithNewOp( - createMaskOp, retTy, - vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); + rewriter.replaceOpWithNewOp(createMaskOp, retTy, + maskDimSizes); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index dfeb7bc53adad7c..bfc05c71f53401f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -111,7 +111,7 @@ class ConstantMaskOpLowering : public OpRewritePattern { if (rank == 0) { assert(dimSizes.size() == 1 && "Expected exactly one dim size for a 0-D vector"); - bool value = cast(dimSizes[0]).getInt() == 1; + bool value = dimSizes.front() == 1; rewriter.replaceOpWithNewOp( op, dstType, DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), @@ -119,7 +119,7 @@ class ConstantMaskOpLowering : public OpRewritePattern { return success(); } - int64_t trueDimSize = cast(dimSizes[0]).getInt(); + int64_t trueDimSize = dimSizes.front(); if (rank == 1) { if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { @@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern { VectorType lowType = VectorType::Builder(dstType).dropDim(0); Value trueVal = rewriter.create( - loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front())); + loc, lowType, dimSizes.drop_front()); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); for (int64_t d = 0; d < trueDimSize; d++) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 7ed3dea42b77156..42ac717b44c4b9e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim return failure(); int64_t dropDim = oldType.getRank() - newType.getRank(); - SmallVector dimSizes; - for (auto attr : mask.getMaskDimSizes()) - dimSizes.push_back(llvm::cast(attr).getInt()); + ArrayRef dimSizes = mask.getMaskDimSizes(); // If any of the dropped unit dims has a size of `0`, the entire mask is a // zero mask, else the unit dim has no effect on the mask. @@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); auto newMask = rewriter.create( - mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes)); + mask.getLoc(), newType, newDimSizes); rewriter.replaceOpWithNewOp(mask, oldType, newMask); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index ac2a4d3abcc68c2..d3296ee38c24969 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -83,17 +83,14 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, newMask = rewriter.create(loc, newMaskType, newMaskOperands); } else if (constantMaskOp) { - ArrayRef maskDimSizes = - constantMaskOp.getMaskDimSizes().getValue(); + ArrayRef maskDimSizes = constantMaskOp.getMaskDimSizes(); size_t numMaskOperands = maskDimSizes.size(); - auto origIndex = - cast(maskDimSizes[numMaskOperands - 1]).getInt(); - IntegerAttr maskIndexAttr = - rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale); - SmallVector newMaskDimSizes(maskDimSizes.drop_back()); - newMaskDimSizes.push_back(maskIndexAttr); - newMask = rewriter.create( - loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes)); + int64_t origIndex = maskDimSizes[numMaskOperands - 1]; + int64_t maskIndex = (origIndex + scale - 1) / scale; + SmallVector newMaskDimSizes(maskDimSizes.drop_back()); + newMaskDimSizes.push_back(maskIndex); + newMask = rewriter.create(loc, newMaskType, + newMaskDimSizes); } while (!extractOps.empty()) {