From eff3f6451cd54055a4d2b6cff29b73ca8e427f8c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 23 Jul 2024 13:51:54 +0000 Subject: [PATCH] Review fixups --- .../Transforms/VectorMaskElimination.cpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp index 486784a9cf102b3..9ad0de5cadeaee9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp @@ -68,26 +68,28 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter, for (auto [i, dimSize] : unknownDims) { // Compute the lower bound for the unknown dimension (i.e. the smallest // value it could be). - auto lowerBound = + FailureOr dimLowerBound = vector::ScalableValueBoundsConstraintSet::computeScalableBound( dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax, presburger::BoundType::LB); - if (failed(lowerBound)) + if (failed(dimLowerBound)) return failure(); - auto boundSize = lowerBound->getSize(); - if (failed(boundSize)) + auto dimLowerBoundSize = dimLowerBound->getSize(); + if (failed(dimLowerBoundSize)) return failure(); - if (boundSize->scalable) { - // If the lower bound is scalable and >= to the mask dim size then this - // dim is all-true. - if (boundSize->baseSize < maskTypeDimSizes[i]) + if (dimLowerBoundSize->scalable) { + // If the lower bound is scalable and < the mask dim size then this dim is + // not all-true. + if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) return failure(); } else { - // If the lower bound is a constant and >= to the _fixed-size_ mask dim - // size then this dim is all-true. + // If the lower bound is a constant: + // - If the mask dim size is scalable then this dim is not all-true. if (maskTypeDimScalableFlags[i]) return failure(); - if (boundSize->baseSize < maskTypeDimSizes[i]) + // - If the lower bound is < the _fixed-size_ mask dim size then this dim + // is not all-true. + if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i]) return failure(); } }