From 5f26497da7de10c4eeec33b5a5cfcb47e96836cc Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Sat, 10 Aug 2024 14:10:24 +0100 Subject: [PATCH] [mlir][vector] Use `DenseI64ArrayAttr` in vector.multi_reduction (#102637) This prevents some unnecessary conversions to/from int64_t and IntegerAttr. --- .../include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 +++++++++-------- .../Transforms/LowerVectorMultiReduction.cpp | 5 +---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 925eb80dbe71ec..b96f5c2651bce5 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp : Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$source, AnyType:$acc, - I64ArrayAttr:$reduction_dims)>, + DenseI64ArrayAttr:$reduction_dims)>, Results<(outs AnyType:$dest)> { let summary = "Multi-dimensional reduction operation"; let description = [{ @@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp : SmallVector getReductionMask() { SmallVector res(getSourceVectorType().getRank(), false); - for (auto ia : getReductionDims().getAsRange()) - res[ia.getInt()] = true; + for (int64_t dim : getReductionDims()) + res[dim] = true; return res; } static SmallVector getReductionMask( diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ab4485c37e5e7f..44bd4aa76ffbd6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder, for (const auto &en : llvm::enumerate(reductionMask)) if (en.value()) reductionDims.push_back(en.index()); - build(builder, result, kind, source, acc, - builder.getI64ArrayAttr(reductionDims)); + build(builder, result, kind, source, acc, reductionDims); } OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) { @@ -466,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() { SmallVector scalableDims; Type inferredReturnType; auto sourceScalableDims = getSourceVectorType().getScalableDims(); - for (auto it : llvm::enumerate(getSourceVectorType().getShape())) - if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return llvm::cast(attr).getValue() == it.index(); - })) { - targetShape.push_back(it.value()); - scalableDims.push_back(sourceScalableDims[it.index()]); + for (auto [dimIdx, dimSize] : + llvm::enumerate(getSourceVectorType().getShape())) + if (!llvm::any_of(getReductionDims(), + [dimIdx = dimIdx](int64_t reductionDimIdx) { + return reductionDimIdx == static_cast(dimIdx); + })) { + targetShape.push_back(dimSize); + scalableDims.push_back(sourceScalableDims[dimIdx]); } // TODO: update to also allow 0-d vectors when available. if (targetShape.empty()) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index ac576ed0b4f097..716da55ba09aec 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Separate reduction and parallel dims - auto reductionDimsRange = - multiReductionOp.getReductionDims().getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(llvm::map_range( - reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); + ArrayRef reductionDims = multiReductionOp.getReductionDims(); llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), reductionDims.end()); int64_t reductionSize = reductionDims.size();