Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes #100997

Merged
merged 1 commit into from
Jul 29, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jul 29, 2024

This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 29, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.


Full diff: https://github.com/llvm/llvm-project/pull/100997.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+17-27)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+7-10)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 39ad03c801140..3cdbd21874567 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
 
 def Vector_ConstantMaskOp :
   Vector_Op<"constant_mask", [Pure]>,
-    Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
+    Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a constant vector mask";
   let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..669ae586e5786 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
     // Inspect constant mask index. If the index exceeds the
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
-    ArrayAttr masks = m.getMaskDimSizes();
+    ArrayRef<int64_t> masks = m.getMaskDimSizes();
     auto shape = m.getType().getShape();
     bool allTrue = true;
     bool allFalse = true;
     for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
-      int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
-      if (i < dimSize)
+      if (maskIdx < dimSize)
         allTrue = false;
-      if (i > 0)
+      if (maskIdx > 0)
         allFalse = false;
     }
     if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
     if (extractStridedSliceOp.hasNonUnitStrides())
       return failure();
     // Gather constant mask dimension sizes.
-    SmallVector<int64_t, 4> maskDimSizes;
-    populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     // Gather strided slice offsets and sizes.
     SmallVector<int64_t, 4> sliceOffsets;
     populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
     // region.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
-        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
+        sliceMaskDimSizes);
     return success();
   }
 };
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
     }
 
     if (constantMaskOp) {
-      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
       auto numMaskOperands = maskDimSizes.size();
 
       // Check every mask dim size to see whether it can be dropped
       for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
            --i) {
-        if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+        if (maskDimSizes[i] != 1)
           return failure();
       }
 
       auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
-      ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
-
       rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
-                                                          newMaskOperandsAttr);
+                                                          newMaskOperands);
       return success();
     }
 
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 
     // ConstantMaskOp case.
     auto maskDimSizes = constantMaskOp.getMaskDimSizes();
-    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
-    applyPermutationToVector(newMaskDimSizes, permutation);
+    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
 
     rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
-        transpOp, transpOp.getResultVectorType(),
-        ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
+        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
     return success();
   }
 };
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
   if (resultType.getRank() == 0) {
     if (getMaskDimSizes().size() != 1)
       return emitError("array attr must have length 1 for 0-D vectors");
-    auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
+    auto dim = getMaskDimSizes()[0];
     if (dim != 0 && dim != 1)
       return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
     return success();
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
   // result dimension size.
   auto resultShape = resultType.getShape();
   auto resultScalableDims = resultType.getScalableDims();
-  SmallVector<int64_t, 4> maskDimSizes;
-  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
-    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
+  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
     if (maskDimSize < 0 || maskDimSize > resultShape[index])
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
         maskDimSize != resultShape[index])
       return emitOpError(
           "only supports 'none set' or 'all set' scalable dimensions");
-    maskDimSizes.push_back(maskDimSize);
   }
   // Verify that if one mask dim size is zero, they all should be zero (because
   // the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
   // Check the corner case of 0-D vectors first.
   if (resultType.getRank() == 0) {
     assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
-    return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
+    return getMaskDimSizes()[0] == 1;
   }
-  for (const auto [resultSize, intAttr] :
+  for (const auto [resultSize, maskDimSize] :
        llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
-    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
     if (maskDimSize < resultSize)
       return false;
   }
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     }
 
     // Replace 'createMaskOp' with ConstantMaskOp.
-    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
-        createMaskOp, retTy,
-        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
+    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
+                                                maskDimSizes);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index dfeb7bc53adad..bfc05c71f5340 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -111,7 +111,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     if (rank == 0) {
       assert(dimSizes.size() == 1 &&
              "Expected exactly one dim size for a 0-D vector");
-      bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
+      bool value = dimSizes.front() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
           DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
@@ -119,7 +119,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       return success();
     }
 
-    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
+    int64_t trueDimSize = dimSizes.front();
 
     if (rank == 1) {
       if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 
     VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
+        loc, lowType, dimSizes.drop_front());
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
     for (int64_t d = 0; d < trueDimSize; d++)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 7ed3dea42b771..3d74502951404 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -551,8 +551,8 @@ struct CastAwayConstantMaskLeadingOneDim
 
     int64_t dropDim = oldType.getRank() - newType.getRank();
     SmallVector<int64_t> dimSizes;
-    for (auto attr : mask.getMaskDimSizes())
-      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+    for (int64_t size : mask.getMaskDimSizes())
+      dimSizes.push_back(size);
 
     // If any of the dropped unit dims has a size of `0`, the entire mask is a
     // zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +563,7 @@ struct CastAwayConstantMaskLeadingOneDim
     newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
 
     auto newMask = rewriter.create<vector::ConstantMaskOp>(
-        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+        mask.getLoc(), newType, newDimSizes);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ac2a4d3abcc68..d3296ee38c249 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                     newMaskOperands);
   } else if (constantMaskOp) {
-    ArrayRef<Attribute> maskDimSizes =
-        constantMaskOp.getMaskDimSizes().getValue();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
-    auto origIndex =
-        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
-    IntegerAttr maskIndexAttr =
-        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
-    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndexAttr);
-    newMask = rewriter.create<vector::ConstantMaskOp>(
-        loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+    int64_t maskIndex = (origIndex + scale - 1) / scale;
+    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+    newMaskDimSizes.push_back(maskIndex);
+    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                      newMaskDimSizes);
   }
 
   while (!extractOps.empty()) {

This prevents a bunch of boilerplate conversions to/from IntegerAttrs
and int64_ts. Other than that this is a NFC.
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

@MacDue
Copy link
Member Author

MacDue commented Jul 29, 2024

If anyone is in the mood for some light refactoring, I think the following ops could get the same treatment:

MultiDimReductionOp
InsertStridedSliceOp
ReshapeOp
ExtractStridedSliceOp

:)

@MacDue MacDue merged commit 0d9b439 into llvm:main Jul 29, 2024
7 checks passed
@MacDue MacDue deleted the less_boilerplate branch July 29, 2024 17:08
MacDue added a commit to MacDue/llvm-project that referenced this pull request Jul 30, 2024
Follow on from llvm#100997. This again removes from boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is a NFC).
MacDue added a commit that referenced this pull request Jul 30, 2024
Follow on from #100997. This again removes from boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is a NFC).
banach-space pushed a commit to banach-space/llvm-project that referenced this pull request Aug 7, 2024
…lvm#100997)

This prevents a bunch of boilerplate conversions to/from IntegerAttrs
and int64_ts. Other than that this is a NFC.
banach-space pushed a commit to banach-space/llvm-project that referenced this pull request Aug 7, 2024
Follow on from llvm#100997. This again removes from boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is a NFC).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants