-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims #66638
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Changes: This extends
Full diff: https://github.com/llvm/llvm-project/pull/66638.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28b5864914f6920..64fbd722a4f02c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2248,7 +2248,10 @@ def Vector_ConstantMaskOp :
define a hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0). Each value of 'mask_dim_sizes' must
be non-negative and not greater than the size of the corresponding vector
- dimension (as opposed to vector.create_mask which allows this).
+ dimension (as opposed to vector.create_mask which allows this). Sizes that
+ correspond to scalable dimensions are implicitly multiplied by vscale,
+ though currently only zero (none set) or the size of the dim/vscale
+ (all set) are supported.
Example:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a8ad05f7bc1cabf..3c68cb26fb55a11 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5320,13 +5320,18 @@ LogicalResult ConstantMaskOp::verify() {
// Verify that each array attr element is in bounds of corresponding vector
// result dimension size.
auto resultShape = resultType.getShape();
+ auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
- for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
- int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
- if (attrValue < 0 || attrValue > resultShape[it.index()])
+ for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
+ int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+ if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
- maskDimSizes.push_back(attrValue);
+ if (resultScalableDims[index] && maskDimSize != 0 &&
+ maskDimSize != resultShape[index])
+ return emitOpError(
+ "only supports 'none set' or 'all set' scalable dimensions");
+ maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
@@ -5335,14 +5340,6 @@ LogicalResult ConstantMaskOp::verify() {
if (anyZeros && !allZeros)
return emitOpError("expected all mask dim sizes to be zeros, "
"as a result of conjunction with zero mask dim");
- // Verify that if the mask type is scalable, dimensions should be zero because
- // constant scalable masks can only be defined for the "none set" or "all set"
- // cases, and there is no VLA way to define an "all set" case for
- // `vector.constant_mask`. In the future, a convention could be established
- // to decide if a specific dimension value could be considered as "all set".
- if (resultType.isScalable() &&
- llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
- return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 9a828ec0b845e4a..418dc6786a76ed4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -105,7 +105,6 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getType();
- auto eltType = dstType.getElementType();
auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
- DenseIntElementsAttr::get(
- VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
- ArrayRef<bool>{value}));
+ DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+ value));
return success();
}
- // Scalable constant masks can only be lowered for the "none set" case.
- if (cast<VectorType>(dstType).isScalable()) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, DenseElementsAttr::get(dstType, false));
- return success();
- }
-
- int64_t trueDim = std::min(dstType.getDimSize(0),
- cast<IntegerAttr>(dimSizes[0]).getInt());
+ int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
if (rank == 1) {
- // Express constant 1-D case in explicit vector form:
- // [T,..,T,F,..,F].
- SmallVector<bool> values(dstType.getDimSize(0));
- for (int64_t d = 0; d < trueDim; d++)
- values[d] = true;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType, rewriter.getBoolVectorAttr(values));
+ if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+ // Use constant splat for 'all set' or 'none set' dims.
+ // This produces correct code for scalable dimensions.
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+ } else {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
+ SmallVector<bool> values(dstType.getDimSize(0));
+ for (int64_t d = 0; d < trueDimSize; d++)
+ values[d] = true;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType, rewriter.getBoolVectorAttr(values));
+ }
return success();
}
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- SmallVector<int64_t> newDimSizes;
- for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+ if (dstType.getScalableDims().front())
+ return rewriter.notifyMatchFailure(
+ op, "Cannot unroll leading scalable dim in dstType");
+
+ VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+ loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0; d < trueDim; d++)
+ for (int64_t d = 0; d < trueDimSize; d++)
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7b29ef44c1f2f2e..27bd5b5ea0eed7b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1819,16 +1819,55 @@ func.func @genbool_1d() -> vector<8xi1> {
// -----
-func.func @genbool_1d_scalable() -> vector<[8]xi1> {
+func.func @genbool_1d_scalable_pfalse() -> vector<[8]xi1> {
%0 = vector.constant_mask [0] : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
-// CHECK-LABEL: func @genbool_1d_scalable
+// CHECK-LABEL: func @genbool_1d_scalable_pfalse
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<[8]xi1>
// CHECK: return %[[VAL_0]] : vector<[8]xi1>
// -----
+func.func @genbool_1d_scalable_ptrue() -> vector<[8]xi1> {
+ %0 = vector.constant_mask [8] : vector<[8]xi1>
+ return %0 : vector<[8]xi1>
+}
+// CHECK-LABEL: func @genbool_1d_scalable_ptrue
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK: return %[[VAL_0]] : vector<[8]xi1>
+
+// -----
+
+func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+ %0 = vector.constant_mask [2, 4] : vector<4x[4]xi1>
+ return %0 : vector<4x[4]xi1>
+}
+// CHECK-LABEL: func.func @genbool_2d_scalable() -> vector<4x[4]xi1> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x[4]xi1>
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x[4]xi1> to !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<[4]xi1>>
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<[4]xi1>> to vector<4x[4]xi1>
+// CHECK: return %[[VAL_5]] : vector<4x[4]xi1>
+// CHECK: }
+
+// -----
+
+/// Currently, this is not supported as generating the mask would require
+/// unrolling the leading scalable dimension at compile time.
+func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+ %0 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+ return %0 : vector<[4]x4xi1>
+}
+// CHECK-LABEL: func.func @cannot_genbool_2d_leading_scalable() -> vector<[4]x4xi1> {
+// CHECK: %[[VAL_0:.*]] = vector.constant_mask [4, 2] : vector<[4]x4xi1>
+// CHECK: return %[[VAL_0]] : vector<[4]x4xi1>
+// CHECK: }
+
+// -----
+
func.func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 50119c2b4a36261..26772b929493585 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -995,7 +995,7 @@ func.func @constant_mask_with_zero_mask_dim_size() {
// -----
func.func @constant_mask_scalable_non_zero_dim_size() {
- // expected-error@+1 {{expected mask dim sizes for scalable masks to be 0}}
+ // expected-error@+1 {{only supports 'none set' or 'all set' scalable dimensions}}
%0 = vector.constant_mask [2] : vector<[8]xi1>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea4379372e8380..96c56946cd1cfff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -448,6 +448,10 @@ func.func @constant_vector_mask() {
%0 = vector.constant_mask [3, 2] : vector<4x3xi1>
// CHECK: vector.constant_mask [0] : vector<[4]xi1>
%1 = vector.constant_mask [0] : vector<[4]xi1>
+ // CHECK: vector.constant_mask [4] : vector<[4]xi1>
+ %2 = vector.constant_mask [4] : vector<[4]xi1>
+ // CHECK: vector.constant_mask [1, 4] : vector<2x[4]xi1>
+ %3 = vector.constant_mask [1, 4] : vector<2x[4]xi1>
return
}
@@ -1003,7 +1007,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
%C: vector<3x[8]xf32>,
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
- %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
+ %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
return %0 : vector<3x[8]xf32>
}
|
Thanks for the contribution. Really interesting... Is the constant input required to always be lower than or equal to the scalable base? For example, what would these two examples mean:
I'm also wondering if it would make sense to distinguish between |
ATM it is required to be "equal".
Indeed, it's unclear and hence I suggest that it's not supported until we figure this out :)
I can't think of a use case for this that couldn't be accommodated with For cases like this one here, it should be sufficient if we have a Vector op that we can use to create an "all true" mask for scalable vectors (instead of using |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification. LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Ben, a couple of minor comments, but otherwise LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (% minor suggestions), thanks!
newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt()); | ||
if (dstType.getScalableDims().front()) | ||
return rewriter.notifyMatchFailure( | ||
op, "Cannot unroll leading scalable dim in dstType"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be able to test for this in invalid.mlir, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this is a match failure, not an invalid op (this is just a fancy `return failure()`); it does not produce a diagnostic.
…calable dims This extends vector.constant_mask so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim/vscale. This allows constant masks to represent all true and all false scalable masks (and some variations): // All true scalable mask %mask = vector.constant_mask [8] : vector<[8]xi1> // All false scalable mask %mask = vector.constant_mask [0] : vector<[8]xi1> // First two scalable rows %mask = vector.constant_mask [2,4] : vector<4x[4]xi1>
343fa06
to
91e4da6
Compare
This extends
vector.constant_mask
so that mask dim sizes that correspond to a scalable dimension are treated as if they're implicitly multiplied by vscale. Currently this is limited to mask dim sizes of 0 or the size of the dim/vscale. This allows constant masks to represent all true and all false scalable masks (and some variations):