Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims #66638

Merged
merged 2 commits into from
Sep 20, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Sep 18, 2023

This extends vector.constant_mask so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim/vscale. This allows constant masks to represent all true and all false scalable masks (and some variations):

// All true scalable mask
%mask = vector.constant_mask [8] : vector<[8]xi1>

// All false scalable mask
%mask = vector.constant_mask [0] : vector<[8]xi1>

// First two scalable rows
%mask = vector.constant_mask [2,4] : vector<4x[4]xi1>

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Changes

This extends vector.constant_mask so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim/vscale. This allows constant masks to represent all true and all false scalable masks (and some variations):

// All true scalable mask
%mask = vector.constant_mask [8] : vector&lt;[8]xi1&gt;

// All false scalable mask
%mask = vector.constant_mask [0] : vector&lt;[8]xi1&gt;

// First two scalable rows
%mask = vector.constant_mask [2,4] : vector&lt;4x[4]xi1&gt;

Full diff: https://github.com/llvm/llvm-project/pull/66638.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-12)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp (+24-27)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+41-2)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+5-1)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28b5864914f6920..64fbd722a4f02c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
     define a hyper-rectangular region within which elements values are set to 1
     (otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
     be non-negative and not greater than the size of the corresponding vector
-    dimension (as opposed to vector.create_mask which allows this).
+    dimension (as opposed to vector.create_mask which allows this). Sizes that
+    correspond to scalable dimensions are implicitly multiplied by vscale,
+    though currently only zero (none set) or the size of the dim/vscale
+    (all set) are supported.
 
     Example:
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a8ad05f7bc1cabf..3c68cb26fb55a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
   // Verify that each array attr element is in bounds of corresponding vector
   // result dimension size.
   auto resultShape = resultType.getShape();
+  auto resultScalableDims = resultType.getScalableDims();
   SmallVector<int64_t, 4> maskDimSizes;
-  for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
-    int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
-    if (attrValue < 0 || attrValue > resultShape[it.index()])
+  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
+    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+    if (maskDimSize < 0 || maskDimSize > resultShape[index])
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
-    maskDimSizes.push_back(attrValue);
+    if (resultScalableDims[index] && maskDimSize != 0 &&
+        maskDimSize != resultShape[index])
+      return emitOpError(
+          "only supports 'none set' or 'all set' scalable dimensions");
+    maskDimSizes.push_back(maskDimSize);
   }
   // Verify that if one mask dim size is zero, they all should be zero (because
   // the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
   if (anyZeros && !allZeros)
     return emitOpError("expected all mask dim sizes to be zeros, "
                        "as a result of conjunction with zero mask dim");
-  // Verify that if the mask type is scalable, dimensions should be zero because
-  // constant scalable masks can only be defined for the "none set" or "all set"
-  // cases, and there is no VLA way to define an "all set" case for
-  // `vector.constant_mask`. In the future, a convention could be established
-  // to decide if a specific dimension value could be considered as "all set".
-  if (resultType.isScalable() &&
-      llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
-    return emitOpError("expected mask dim sizes for scalable masks to be 0");
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 9a828ec0b845e4a..418dc6786a76ed4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto dstType = op.getType();
-    auto eltType = dstType.getElementType();
     auto dimSizes = op.getMaskDimSizes();
     int64_t rank = dstType.getRank();
 
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
+          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+                                    value));
       return success();
     }
 
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (cast<VectorType>(dstType).isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               cast<IntegerAttr>(dimSizes[0]).getInt());
+    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
 
     if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
+      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+        // Use constant splat for 'all set' or 'none set' dims.
+        // This produces correct code for scalable dimensions.
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+      } else {
+        // Express constant 1-D case in explicit vector form:
+        //   [T,..,T,F,..,F].
+        SmallVector<bool> values(dstType.getDimSize(0));
+        for (int64_t d = 0; d < trueDimSize; d++)
+          values[d] = true;
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, dstType, rewriter.getBoolVectorAttr(values));
+      }
       return success();
     }
 
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t> newDimSizes;
-    for (int64_t r = 1; r < rank; r++)
-      newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+    if (dstType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          op, "Cannot unroll leading scalable dim in dstType");
+
+    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < trueDim; d++)
+    for (int64_t d = 0; d < trueDimSize; d++)
       result =
           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7b29ef44c1f2f2e..27bd5b5ea0eed7b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1819,16 +1819,55 @@ func.func @genbool_1d() -> vector<8xi1> {
 
 // -----
 
-func.func @genbool_1d_scalable() -> vector<[8]xi1> {
+func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> {
   %0 = vector.constant_mask [0] : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
-// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK-LABEL: func @genbool_1d_scalable_pfalse
 // CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
 // CHECK: return %[[VAL_0]] : vector<[8]xi1>
 
 // -----
 
+func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> {
+  %0 = vector.constant_mask [8] : vector<[8]xi1>
+  return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable_ptrue
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
+func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+  %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
+  return %0 : vector<4x[4]xi1>
+}
+// CHECK-LABEL:   func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
+// CHECK:           %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK:           %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
+// CHECK:           return %[[VAL_5]] : vector<4x[4]xi1>
+// CHECK:         }
+
+// -----
+
+/// Currently, this is not supported as generating the mask would require
+/// unrolling the leading scalable dimension at compile time.
+func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+  %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+  return %0 : vector<[4]x4xi1>
+}
+// CHECK-LABEL:   func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+// CHECK:           %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+// CHECK:           return %[[VAL_0]] : vector<[4]x4xi1>
+// CHECK:         }
+
+// -----
+
 func.func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
   return %v: vector<4x4xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 50119c2b4a36261..26772b929493585 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
 // -----
 
 func.func @constant_mask_scalable_non_zero_dim_size() {
-  // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+  // expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}}
   %0 = vector.constant_mask [2] : vector<[8]xi1>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea4379372e8380..96c56946cd1cfff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
   // CHECK: vector.constant_mask [0] : vector<[4]xi1>
   %1 = vector.constant_mask [0] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [4] : vector<[4]xi1>
+  %2 = vector.constant_mask [4] : vector<[4]xi1>
+  // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
+  %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
   return
 }
 
@@ -1003,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
                                     %C: vector<3x[8]xf32>,
                                     %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
  // CHECK:  vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
-  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } 
+  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
     : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
   return %0 : vector<3x[8]xf32>
 }

@dcaballe
Copy link
Contributor

THanks for the contribution. Really interesting... Is the constant input required to be always lower or equal to the scalable base?

For example, what these two examples would mean:

%mask = vector.constant_mask [2, 7] : vector<4x[4]xi1>
%mask = vector.constant_mask [2, 3] : vector<4x[4]xi1>

I'm also wondering if it would make sense to distinguish between [2, 4] : vector<4x[4]xi1> and [2, [4]] : vector<4x[4]xi1> as they may mean completely different things?

@banach-space
Copy link
Contributor

Is the constant input required to be always lower or equal to the scalable base?

ATM it is required to be "equal".

For example, what these two examples would mean:

%mask = vector.constant_mask [2, 7] : vector<4x[4]xi1>
%mask = vector.constant_mask [2, 3] : vector<4x[4]xi1>

Indeed, it's unclear and hence I suggest that it's not supported until we figure this out :)

I'm also wondering if it would make sense to distinguish between [2, 4] : vector<4x[4]xi1> and [2, [4]] : vector<4x[4]xi1> as they may mean completely different things?

I can't think of a use case for this that couldn't be accommodated with vector.create_mask.

For cases like this one here, it should be sufficient if we have a Vector op that we can use to create an "all true" mask for scalable vectors (instead of usingarith.constant dense<true> : vector<[8]xi1>). IIUC, this is what this patch strives to achieve. As for more complex cases, we should get a better idea in the coming weeks/months as we continue extending/refining support for SVE and scalable vectors.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification. LGTM!

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Ben couple of minor comments but otherwise LGTM

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (% minor suggestions), thanks!

newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
if (dstType.getScalableDims().front())
return rewriter.notifyMatchFailure(
op, "Cannot unroll leading scalable dim in dstType");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to test for this in invalid.mlir, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is a match failure not an invalid op (this is just a fancy return failure()), it does not produce a diagnostic.

…calable dims

This extends vector.constant_mask so that mask dim sizes that correspond
to a scalable dimension are treated as if they're implicitly multipled
by vscale. Currently this is limited to mask dim sizes of 0 or the
size of the dim/vscale. This allows constant masks to represent all true
and all false scalable masks (and some variations):

  // All true scalable mask
  %mask = vector.constant_mask [8] : vector<[8]xi1>

  // All false scalable mask
  %mask = vector.constant_mask [0] : vector<[8]xi1>

  // First two scalable rows
  %mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
@MacDue MacDue force-pushed the constant_mask_scalable_5 branch from 343fa06 to 91e4da6 Compare September 20, 2023 11:09
@MacDue MacDue merged commit 2f11ce5 into llvm:main Sep 20, 2023
@MacDue MacDue deleted the constant_mask_scalable_5 branch September 20, 2023 20:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants