Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add mask elimination transform #99314

Merged
merged 7 commits into from
Aug 9, 2024
Merged

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jul 17, 2024

This adds a new transform eliminateVectorMasks() which aims at removing scalable vector.create_masks that will be all-true at runtime. It attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations), if that does not lead to an answer (is all-true? yes/no), then value bounds analysis will be used to find the lower bound of the unknown operands. If the lower bound is >= to the corresponding mask vector type dim, then that dimension of the mask is all true.

Note that the pattern matching prevents expensive value-bounds analysis in cases where the mask won't be all true.

For example:

%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>

From looking at %c2 we can tell this is not going to be an all-true mask, so we don't need to run the value-bounds analysis for
%dynamicValue (and can exit the transform early).

Note: Eliminating create_masks here means replacing them with all-true constants (which will then lead to the masks folding away).

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 17, 2024

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes

This adds a new transform eliminateVectorMasks() which aims at removing scalable vector.create_masks that will be all-true at runtime. It attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations), if that does not lead to an answer (is all-true? yes/no), then value bounds analysis will be used to find the lower bound of the unknown operands. If the lower bound is >= to the corresponding mask vector type dim, then that dimension of the mask is all true.

Note: Eliminating create_masks here means replacing them with all-true constants (which will then lead to the masks folding away).


Full diff: https://github.com/llvm/llvm-project/pull/99314.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h (+17)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp (+117)
  • (added) mlir/test/Dialect/Vector/eliminate-masks.mlir (+138)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+34)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 1f7d6411cd5a4..847f333d6a931 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class MLIRContext;
@@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                  MaskingOpInterface maskingOp,
                                  RewriterBase &rewriter);
 
+/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take.
+struct VscaleRange {
+  unsigned vscaleMin;
+  unsigned vscaleMax;
+};
+
+/// Attempts to eliminate redundant vector masks by replacing them with all-true
+/// constants at the top of the function (which results in the masks folding
+/// away). Note: Currently, this only runs for vector.create_mask ops and
+/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
+/// nothing. This is because these redundant masks are much more likely for
+/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
+/// code has static sizes, so simpler folds remove the masks.
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange = {});
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 723b2f62d65d4..2639a67e1c8b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
+  VectorMaskElimination.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
new file mode 100644
index 0000000000000..abec8c75b8fc9
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -0,0 +1,117 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+namespace {
+
+/// If `value` is a constant multiple of `vector.vscale` return the multiplier.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
+  if (value.getDefiningOp<vector::VectorScaleOp>())
+    return 1;
+  auto mul = value.getDefiningOp<arith::MulIOp>();
+  if (!mul)
+    return {};
+  auto lhs = mul.getLhs();
+  auto rhs = mul.getRhs();
+  if (lhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(rhs);
+  if (rhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(lhs);
+  return {};
+}
+
+/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
+/// All-true masks can then be eliminated by simple folds.
+LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
+                                         vector::CreateMaskOp createMaskOp,
+                                         VscaleRange vscaleRange) {
+  auto maskType = createMaskOp.getVectorType();
+  auto maskTypeDimScalableFlags = maskType.getScalableDims();
+  auto maskTypeDimSizes = maskType.getShape();
+
+  struct UnknownMaskDim {
+    size_t position;
+    Value dimSize;
+  };
+
+  // Check for any dims that could be (partially) false before doing the more
+  // expensive value bounds computations.
+  SmallVector<UnknownMaskDim> unknownDims;
+  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      // Mask not all-true for this dim.
+      if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
+        return failure();
+    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+      // Mask not all-true for this dim.
+      if (vscaleMultiplier < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // Unknown (without further analysis).
+      unknownDims.push_back(UnknownMaskDim{i, dimSize});
+    }
+  }
+
+  for (auto [i, dimSize] : unknownDims) {
+    // Compute the lower bound for the unknown dimension (i.e. the smallest
+    // value it could be).
+    auto lowerBound =
+        vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+            dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
+            presburger::BoundType::LB);
+    if (failed(lowerBound))
+      return failure();
+    auto boundSize = lowerBound->getSize();
+    if (failed(boundSize))
+      return failure();
+    if (boundSize->scalable) {
+      // If the lower bound is scalable and >= to the mask dim size then this
+      // dim is all-true.
+      if (boundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // If the lower bound is a constant and >= to the _fixed-size_ mask dim
+      // size then this dim is all-true.
+      if (maskTypeDimScalableFlags[i])
+        return failure();
+      if (boundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    }
+  }
+
+  // Replace createMaskOp with an all-true constant. This should result in the
+  // mask being removed in most cases (as xfer ops + vector.mask have folds to
+  // remove all-true masks).
+  auto allTrue = rewriter.create<arith::ConstantOp>(
+      createMaskOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+  rewriter.replaceAllUsesWith(createMaskOp, allTrue);
+  return success();
+}
+
+} // namespace
+
+namespace mlir::vector {
+
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange) {
+  // TODO: Support fixed-size case. This is less likely to be useful as for
+  // fixed-size code dimensions are all static so masks tend to fold away.
+  if (!vscaleRange)
+    return;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  SmallVector<vector::CreateMaskOp> worklist;
+  function.walk([&](vector::CreateMaskOp createMaskOp) {
+    worklist.push_back(createMaskOp);
+  });
+  rewriter.setInsertionPointToStart(&function.front());
+  for (auto mask : worklist)
+    (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
+}
+
+} // namespace mlir::vector
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
new file mode 100644
index 0000000000000..99c9a60a09fac
--- /dev/null
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks  | FileCheck %s
+
+// This tests a general pattern the vectorizer tends to emit.
+
+// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
+// CHECK: %[[ALL_TRUE_MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
+// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
+func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
+  %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
+    // 1. Extract a slice.
+    %extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+
+    // 2. Create a mask for the slice.
+    %dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim_1 : vector<[4]xi1>
+
+    // 3. Read the slice and do some computation.
+    %vec = vector.transfer_read %extracted_slice_1[%c0], %c0_f32, %mask {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32>
+    %new_vec = "test.some_computation"(%vec) : (vector<[4]xf32>) -> (vector<[4]xf32>)
+
+    // 4. Write the new value.
+    %write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>
+
+    // 5. Insert and yield the new tensor value.
+    %result = tensor.insert_slice %write into %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    scf.yield %result : tensor<1x?xf32>
+  }
+  "test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_extract_slice_size_shrink
+// CHECK-NOT: arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<[4]xi1>) -> ()
+func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice = tensor.extract_slice %tensor[0] [%c4_vscale] [1] : tensor<1000xf32> to tensor<?xf32>
+  %slice = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice) -> tensor<?xf32> {
+    // This mask cannot be eliminated even though looking at the above operations
+    // it appears `tensor.dim` will always be c4_vscale (so the mask all-true).
+    %dim = tensor.dim %arg, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim : vector<[4]xi1>
+    "test.some_use"(%mask) : (vector<[4]xi1>) -> ()
+    // !!! Here the size of the mask could shrink in the next iteration.
+    %next_num_els = affine.min  affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
+    %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor<?xf32>
+    scf.yield %new_extracted_slice : tensor<?xf32>
+  }
+  "test.some_use"(%slice) : (tensor<?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_dim_not_all_true()
+{
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %mask = vector.create_mask %c1, %c4_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_vscale_multiple_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_vscale_multiple_not_all_true() {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  %mask = vector.create_mask %c2, %c3_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_fixed_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // This is _very_ simple but since addi is not a constant value bounds will
+  // be used to resolve it.
+  %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
+  %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_scalable_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_scalable_dim_not_all_true(%tensor: tensor<2x100xf32>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  %slice = tensor.extract_slice %tensor[0, 0] [2, %c3_vscale] [1, 1] : tensor<2x100xf32> to tensor<2x?xf32>
+  // Another simple example, but value bounds will be used to resolve the tensor.dim.
+  %dim = tensor.dim %slice, %c1 : tensor<2x?xf32>
+  %mask = vector.create_mask %c3, %dim : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index c978699e179fc..f74ff2725f815 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -874,6 +874,38 @@ struct TestVectorLinearize final
       return signalPassFailure();
   }
 };
+
+struct TestEliminateVectorMasks
+    : public PassWrapper<TestEliminateVectorMasks,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
+
+  TestEliminateVectorMasks() = default;
+  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
+      : PassWrapper(pass) {}
+
+  Option<unsigned> vscaleMin{
+      *this, "vscale-min",
+      llvm::cl::desc(
+          "Minimum value `vector.vscale` can possibly be at runtime."),
+      llvm::cl::init(1)};
+
+  Option<unsigned> vscaleMax{
+      *this, "vscale-max",
+      llvm::cl::desc(
+          "Maximum value `vector.vscale` can possibly be at runtime."),
+      llvm::cl::init(16)};
+
+  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
+  StringRef getDescription() const final {
+    return "Test eliminating vector masks";
+  }
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+    eliminateVectorMasks(rewriter, getOperation(),
+                         VscaleRange{vscaleMin, vscaleMax});
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -920,6 +952,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
   PassRegistration<TestVectorLinearize>();
+
+  PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
 } // namespace mlir

@MacDue MacDue force-pushed the mask_elim branch 2 times, most recently from e8ff88d to 7af1229 Compare July 19, 2024 10:02
@dcaballe
Copy link
Contributor

Thanks! A couple of high-level comments. I'll come back :)

  1. Could we move the trivial pattern-matching cases to the folder/canonicalizer? We really don't want to have any trivial all-one mask around both for the fixed and scalable cases. It makes sense to have a more sophisticated and expensive pass to get rid of the tricky cases that require value bounds or range analysis.

  2. Could we use constant_mask instead of arith.constant? This is not enforced in any way but we've been using the former to represent vector masks (they are kind of redundant, actually).

@MacDue
Copy link
Member Author

MacDue commented Jul 22, 2024

  1. The simple pattern matching is done while collecting the unknown dimensions, and functions as an early-exit, as if we can tell that a mask is not going to be all-true we can bail out before doing any more expensive value-bounds analysis. Patterns like this already exist as canonizations, but just give-up on unknown dims rather than trying to resolve them.

// Check for any dims that could be (partially) false before doing the more
// expensive value bounds computations.
SmallVector<UnknownMaskDim> unknownDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Mask not all-true for this dim.
if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
return failure();
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Mask not all-true for this dim.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
} else {
// Unknown (without further analysis).
unknownDims.push_back(UnknownMaskDim{i, dimSize});
}
}

  1. I think I'd need to add a nice builder for making all-true/false constant masks first, currently, it's much less ergonomic to build a constant_mask vs an arith.constant, since the only builder for constant_masks takes an ArrayAttr... So you have to build an array of IntegerAttr, which is a bit of a pain.

@dcaballe
Copy link
Contributor

The simple pattern matching is done while collecting the unknown dimensions, and functions as an early-exit, as if we can tell that a mask is not going to be all-true we can bail out before doing any more expensive value-bounds analysis. Patterns like this already exist as canonizations, but just give-up on unknown dims rather than trying to resolve them

Sorry, I'm not sure what this comment implies but, yes, I'm meant that the extra pattern matching should be part of canonicalization. If you also need those patterns here to make the pass reach a fixed point, it should be fine to also populate them. As noted, I would expect this to also simplify fixed-length vectors so if the existing canonicalization ones also include fixed-length pattern it should be ok to also have them here.

@MacDue
Copy link
Member Author

MacDue commented Jul 23, 2024

It means value bounds analysis would be needlessly run on CreateMask operations that from simple inspection are not all-true masks.

Say you had create_mask %dynamic, %c2 : vector<4x[4]xi1>. There's no canonicalization that will change that, however it's obvious from simply pattern matching the second dim that this won't be an all-true mask. So by checking these cases while we look for unknown dims we can bail out before doing value bounds analysis. This check is like what the canonicalizations do, but actually using the canonicalizations won't give an early exit for obviously not all-true cases.

@banach-space
Copy link
Contributor

  • Could we move the trivial pattern-matching cases to the folder/canonicalizer?

+1 IIUC, Diego is asking for more code re-use. In particular, if similar logic is already available somewhere in folder/conanicalization "landscape", it should be re-used here:

struct UnknownMaskDim {
    size_t position;
    Value dimSize;
};

// Return `Failure` if at least one Mask dim is statically known not to be "all true". Otherwise,
// returns a vector of `UnknownMaskDim` yet to be analysed to see whether these are "all-true" lanes.
// If the return value is an empty vector, all Mask dims are statically known to be "all-true".
FailureOr<SmallVector<UnknownMaskDim>> getPossiblyTrueMaskDim() {
  SmallVector<UnknownMaskDim> unknownDims;
  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      // Mask not all-true for this dim.
      if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
        return failure();
    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
      // Mask not all-true for this dim.
      if (vscaleMultiplier < maskTypeDimSizes[i])
        return failure();
    } else {
      // Unknown (without further analysis).
      unknownDims.push_back(UnknownMaskDim{i, dimSize});
    }
  }
  return unknownDims;
}

From what I can tell, you can wrap this into an utilility funciton ,e.g.

SmallVector<>
  • Could we use constant_mask instead of arith.constant? This is not enforced in any way but we've been using the former to represent vector masks (they are kind of redundant, actually).

+1 to this suggestion. @dcaballe , are you OK to implement this as a follow-up?

@MacDue
Copy link
Member Author

MacDue commented Jul 29, 2024

+1 IIUC, Diego is asking for more code re-use. In particular, if similar logic is already available somewhere in folder/conanicalization "landscape", it should be re-used here:

There's nowhere (other than here) that wants to check for not all-true create masks, and collect unknown dimensions 🙂 I was just saying the logic is similar, but not really to a reusable extent imo.

@banach-space
Copy link
Contributor

+1 IIUC, Diego is asking for more code re-use. In particular, if similar logic is already available somewhere in folder/conanicalization "landscape", it should be re-used here:

There's nowhere (other than here) that wants to check for not all-true create masks, and collect unknown dimensions 🙂 I was just saying the logic is similar, but not really to a reusable extent imo.

Then you could still add a helper somewhere in Vector/Utils.

@dcaballe
Copy link
Contributor

dcaballe commented Aug 2, 2024

What I'm suggesting here is:

  1. Move lightweight mask removal patterns to the op canonicalization. We want to remove trivial masks as part of the canonicalization. Create a populateLightweightMaskRemovalPatterns for them.
  2. Keep the heavyweight mask removal patterns in the pass and also add the patterns in populateLightweightMaskRemovalPatterns for completeness.

In that way, we would have the trivial ones as part of canonicalization and when more heavyweight analysis is needed to remove the masks, we will use the pass. Does it make more sense now or am I missing something?

@MacDue
Copy link
Member Author

MacDue commented Aug 2, 2024

Does it make more sense now or am I missing something?

I think you are misunderstanding how this patch works. There are no lightweight canonicalizations added here (or any patterns really). Such lightweight canonicalizations already exist, and this code does not re-implement them, I was simply saying this code follows a similar idea.

Let's look at the actual code here:

  // Check for any dims that could be (partially) false before doing the more
  // expensive value bounds computations.
  SmallVector<UnknownMaskDim> unknownDims;
  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      // Mask not all-true for this dim.
      if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
        return failure();
    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
      // Mask not all-true for this dim.
      if (vscaleMultiplier < maskTypeDimSizes[i])
        return failure();
    } else {
      // Unknown (without further analysis).
      unknownDims.push_back(UnknownMaskDim{i, dimSize});
    }
  }

This is the part likened to canonicalization, but it's not a canonicalization. It's looking at all the operations of the create_mask and finding the unknown dimensions. Those are the dimensions that need to be solved via value-bounds analysis. The trick here is if any dimension is constant and not all-true, then we can exit early as that means the mask won't be all-true.

This prevents us from pointlessly doing value-bounds analysis is in cases like:

%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>

From looking at %c2 we can tell this is not going to be an all-true mask, so we don't need to run the value-bounds analysis for
%dynamicValue (and will exit the transform early). Note that this create_mask still is not a constant_mask or an all-true/false mask, so no canonicalization will remove it before the "heavyweight" mask removal.

After that we go over the unknown dimensions and solve for them using value-bounds:

  for (auto [i, dimSize] : unknownDims) {
    // Compute the lower bound for the unknown dimension (i.e. the smallest
    // value it could be).
    FailureOr<ConstantOrScalableBound> dimLowerBound =
        vector::ScalableValueBoundsConstraintSet::computeScalableBound(
            dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
            presburger::BoundType::LB);
    if (failed(dimLowerBound))
      return failure();
    auto dimLowerBoundSize = dimLowerBound->getSize();
    if (failed(dimLowerBoundSize))
      return failure();
    if (dimLowerBoundSize->scalable) {
      // If the lower bound is scalable and < the mask dim size then this dim is
      // not all-true.
      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
        return failure();
    } else {
      // If the lower bound is a constant:
      // - If the mask dim size is scalable then this dim is not all-true.
      if (maskTypeDimScalableFlags[i])
        return failure();
      // - If the lower bound is < the _fixed-size_ mask dim size then this dim
      // is not all-true.
      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
        return failure();
    }
  }

The TLDR is this is just the "heavyweight" mask removal with a check to avoid doing pointless work.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, forgot to press "send" before.

mlir/test/Dialect/Vector/eliminate-masks.mlir Outdated Show resolved Hide resolved
mlir/test/Dialect/Vector/eliminate-masks.mlir Outdated Show resolved Hide resolved
@dcaballe
Copy link
Contributor

dcaballe commented Aug 3, 2024

Thanks for the clarifications!

I think you are misunderstanding how this patch works. There are no lightweight canonicalizations added here (or any patterns really). Such lightweight canonicalizations already exist, and this code does not re-implement them, I was simply saying this code follows a similar idea.

Sorry if that is the case but I'm not sure how else I could interpret this:

It attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations), if that does not lead to an answer (is all-true? yes/no), then value bounds analysis will be used to find the lower bound of the unknown operands.

The lightweight part that I'm referring as a canonicalization, based on this sentence, is the "pattern-matching the mask operands" that leads to an all-true answer without using the value bounds analysis.

More specifically:

Let's look at the actual code here:

My understanding is that if unknownDims is empty after that loop and no failure() path is taken, we are in an all-true mask case that do not require value bounds analysis (i.e., the lightweight case). If the answer to this is that this canonicalization already exists, then great! That's basically what I've been asking for :).

@MacDue
Copy link
Member Author

MacDue commented Aug 5, 2024

My understanding is that if unknownDims is empty after that loop

That simply works as it would be more code to disable it. But there's already canonicalization of create_mask -> constant_mask that would handle those cases.

@MacDue
Copy link
Member Author

MacDue commented Aug 5, 2024

I've now added a simple builder for making true/false constant_masks (after simplifying the representation in #100997).

  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
      createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);

This adds a new transform `eliminateVectorMasks()` which aims at
removing scalable `vector.create_masks` that will be all-true at
runtime. It attempts to do this by simply pattern-matching the mask
operands (similar to some canonicalizations), if that does not lead to
an answer (is all-true? yes/no), then value bounds analysis will be used
to find the lower bound of the unknown operands. If the lower bound is
>= to the corresponding mask vector type dim, then that dimension of the
mask is all true.

Note: Eliminating create_masks here means replacing them with all-true
constants (which will then lead to the masks folding away).
The main thing shared here is the `getConstantVscaleMultiplier()`
matcher, I could not think of a good way to share all the logic as it's
somewhat different.
@MacDue
Copy link
Member Author

MacDue commented Aug 6, 2024

In the last commit, I've shared some logic with the CreateMaskFolder. The main thing shared here is the new getConstantVscaleMultiplier() matcher. I could not think of a good way to share all the logic, as it's a little different in both places.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the refactor and code re-use. The split between CreateMaskFolder and eliminateVectorMasks is much clearer now.

It would be good to briefly contrast what you are adding here with the current folder in the summary. That's where most people will be looking for documentation.

I believe that this addresses all of Diego's concerns. I know that he's OOO over the coming weeks and it would be really good to land this sooner rather than later. I will accept this (% my comments re documentation) and let's land this. If there's more suggestions from Diego post-commit, we can address those in follow-up PRs.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp Outdated Show resolved Hide resolved
@MacDue MacDue merged commit 9b06e25 into llvm:main Aug 9, 2024
8 checks passed
@MacDue MacDue deleted the mask_elim branch August 9, 2024 09:51
kutemeikito added a commit to kutemeikito/llvm-project that referenced this pull request Aug 10, 2024
* 'main' of https://github.com/llvm/llvm-project: (700 commits)
  [SandboxIR][NFC] SingleLLVMInstructionImpl class (llvm#102687)
  [ThinLTO]Clean up 'import-assume-unique-local' flag. (llvm#102424)
  [nsan] Make #include more conventional
  [SandboxIR][NFC] Use Tracker.emplaceIfTracking()
  [libc]  Moved range_reduction_double ifdef statement (llvm#102659)
  [libc] Fix CFP long double and add tests (llvm#102660)
  [TargetLowering] Handle vector types in expandFixedPointMul (llvm#102635)
  [compiler-rt][NFC] Replace environment variable with %t (llvm#102197)
  [UnitTests] Convert a test to use opaque pointers (llvm#102668)
  [CodeGen][NFCI] Don't re-implement parts of ASTContext::getIntWidth (llvm#101765)
  [SandboxIR] Clean up tracking code with the help of emplaceIfTracking() (llvm#102406)
  [mlir][bazel] remove extra blanks in mlir-tblgen test
  [NVPTX][NFC] Update tests to use bfloat type (llvm#101493)
  [mlir] Add support for parsing nested PassPipelineOptions (llvm#101118)
  [mlir][bazel] add missing td dependency in mlir-tblgen test
  [flang][cuda] Fix lib dependency
  [libc] Clean up remaining use of *_WIDTH macros in printf (llvm#102679)
  [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (llvm#102662)
  [SandboxIR] Implement the InsertElementInst class (llvm#102404)
  [libc] Fix use of cpp::numeric_limits<...>::digits (llvm#102674)
  [mlir][ODS] Verify type constraints in Types and Attributes (llvm#102326)
  [LTO] enable `ObjCARCContractPass` only on optimized build  (llvm#101114)
  [mlir][ODS] Consistent `cppType` / `cppClassName` usage (llvm#102657)
  [lldb] Move definition of SBSaveCoreOptions dtor out of header (llvm#102539)
  [libc] Use cpp::numeric_limits in preference to C23 <limits.h> macros (llvm#102665)
  [clang] Implement -fptrauth-auth-traps. (llvm#102417)
  [LLVM][rtsan] rtsan transform to preserve CFGAnalyses (llvm#102651)
  Revert "[AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)"
  [RISCV][GISel] Add missing tests for G_CTLZ/CTTZ instruction selection. NFC
  Return available function types for BindingDecls. (llvm#102196)
  [clang] Wire -fptrauth-returns to "ptrauth-returns" fn attribute. (llvm#102416)
  [RISCV] Remove riscv-experimental-rv64-legal-i32. (llvm#102509)
  [RISCV] Move PseudoVSET(I)VLI expansion to use PseudoInstExpansion. (llvm#102496)
  [NVPTX] support switch statement with brx.idx (reland) (llvm#102550)
  [libc][newhdrgen]sorted function names in yaml (llvm#102544)
  [GlobalIsel] Combine G_ADD and G_SUB with constants (llvm#97771)
  Suppress spurious warnings due to R_RISCV_SET_ULEB128
  [scudo] Separated committed and decommitted entries. (llvm#101409)
  [MIPS] Fix missing ANDI optimization (llvm#97689)
  [Clang] Add env var for nvptx-arch/amdgpu-arch timeout (llvm#102521)
  [asan] Switch allocator to dynamic base address (llvm#98511)
  [AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)
  [libc][math][c23] Add fadd{l,f128} C23 math functions (llvm#102531)
  [mlir][bazel] revert bazel rule change for DLTITransformOps
  [msan] Support vst{2,3,4}_lane instructions (llvm#101215)
  Revert "[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)"
  [X86] pr57673.ll - generate MIR test checks
  [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (llvm#102584)
  [mlir][bazel] add bazel rule for DLTITransformOps
  OpenMPOpt: Remove dead include
  [IR] Add method to GlobalVariable to change type of initializer. (llvm#102553)
  [flang][cuda] Force default allocator in device code (llvm#102238)
  [llvm] Construct SmallVector<SDValue> with ArrayRef (NFC) (llvm#102578)
  [MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
  [AMDGPU][AsmParser][NFC] Remove a misleading comment. (llvm#102604)
  [Arm][AArch64][Clang] Respect function's branch protection attributes. (llvm#101978)
  [mlir] Verifier: steal bit to track seen instead of set. (llvm#102626)
  [Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII (llvm#100766)
  [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
  [gn] Give two scripts argparse.RawDescriptionHelpFormatter
  [bazel] Add missing dep for the SPIRVToLLVM target
  [Clang] Simplify specifying passes via -Xoffload-linker (llvm#102483)
  [bazel] Port for d45de80
  [SelectionDAG] Use unaligned store/load to move AVX registers onto stack for `insertelement` (llvm#82130)
  [Clang][OMPX] Add the code generation for multi-dim `num_teams` (llvm#101407)
  [ARM] Regenerate big-endian-vmov.ll. NFC
  [AMDGPU][AsmParser][NFCI] All NamedIntOperands to be of the i32 type. (llvm#102616)
  [libc][math][c23] Add totalorderl function. (llvm#102564)
  [mlir][spirv] Support `memref` in `convert-to-spirv` pass (llvm#102534)
  [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (llvm#101664)
  Fix a unit test input file (llvm#102567)
  [llvm-readobj][COFF] Dump hybrid objects for ARM64X files. (llvm#102245)
  AMDGPU/NewPM: Port SIFixSGPRCopies to new pass manager (llvm#102614)
  [MemoryBuiltins] Simplify getCalledFunction() helper (NFC)
  [AArch64] Add invalid 1 x vscale costs for reductions and reduction-operations. (llvm#102105)
  [MemoryBuiltins] Handle allocator attributes on call-site
  LSV/test/AArch64: add missing lit.local.cfg; fix build (llvm#102607)
  Revert "Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)"
  [RISCV] Add Syntacore SCR5 RV32/64 processors definition (llvm#102285)
  [InstCombine] Remove unnecessary RUN line from test (NFC)
  [flang][OpenMP] Handle multiple ranges in `num_teams` clause (llvm#102535)
  [mlir][vector] Add tests for scalable vectors in one-shot-bufferize.mlir (llvm#102361)
  [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (llvm#102573)
  [clang] Implement CWG2627 Bit-fields and narrowing conversions (llvm#78112)
  [NFC] Use references to avoid copying (llvm#99863)
  Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (llvm#100731)" (llvm#102457)
  [IRBuilder] Generate nuw GEPs for struct member accesses (llvm#99538)
  [bazel] Port for 9b06e25
  [CodeGen][NewPM] Improve start/stop pass error message CodeGenPassBuilder (llvm#102591)
  [AArch64] Implement TRBMPAM_EL1 system register (llvm#102485)
  [InstCombine] Fixing wrong select folding in vectors with undef elements (llvm#102244)
  [AArch64] Sink operands to fmuladd. (llvm#102297)
  LSV: document hang reported in llvm#37865 (llvm#102479)
  Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)
  [RISCV][clang] Remove bfloat base type in non-zvfbfmin vcreate (llvm#102146)
  [RISCV][clang] Add missing `zvfbfmin` to `vget_v` intrinsic (llvm#102149)
  [mlir][vector] Add mask elimination transform (llvm#99314)
  [Clang][Interp] Fix display of syntactically-invalid note for member function calls (llvm#102170)
  [bazel] Port for 3fffa6d
  [DebugInfo][RemoveDIs] Use iterator-inserters in clang (llvm#102006)
  ...

Signed-off-by: Edwiin Kusuma Jaya <kutemeikito0905@gmail.com>
c-rhodes pushed a commit to iree-org/iree that referenced this pull request Aug 14, 2024
…18190)

This enables an upstream transform that eliminates all true
`vector.create_mask` ops. This is particularly beneficial for scalable
vectors, which use dynamic tensor types, which results in masks that
otherwise would not fold away till much later, preventing some
optimizations.

Depends on llvm/llvm-project#99314.

---------

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>
bwendling pushed a commit to bwendling/llvm-project that referenced this pull request Aug 15, 2024
This adds a new transform `eliminateVectorMasks()` which aims at
removing scalable `vector.create_masks` that will be all-true at
runtime. It attempts to do this by simply pattern-matching the mask
operands (similar to some canonicalizations), if that does not lead to
an answer (is all-true? yes/no), then value bounds analysis will be used
to find the lower bound of the unknown operands. If the lower bound is
>= to the corresponding mask vector type dim, then that dimension of the
mask is all true.

Note that the pattern matching prevents expensive value-bounds analysis
in cases where the mask won't be all true.

For example:
```mlir
%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>
```
From looking at `%c2` we can tell this is not going to be an all-true
mask, so we don't need to run the value-bounds analysis for
`%dynamicValue` (and can exit the transform early).

Note: Eliminating create_masks here means replacing them with all-true
constants (which will then lead to the masks folding away).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants