diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 925eb80dbe71ec..b96f5c2651bce5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
                    AnyType:$acc,
-                   I64ArrayAttr:$reduction_dims)>,
+                   DenseI64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
   let description = [{
@@ -325,8 +325,8 @@
     SmallVector<bool> getReductionMask() {
       SmallVector<bool> res(getSourceVectorType().getRank(), false);
-      for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
-        res[ia.getInt()] = true;
+      for (int64_t dim : getReductionDims())
+        res[dim] = true;
       return res;
     }
     static SmallVector<bool> getReductionMask(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab4485c37e5e7f..44bd4aa76ffbd6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  build(builder, result, kind, source, acc,
-        builder.getI64ArrayAttr(reductionDims));
+  build(builder, result, kind, source, acc, reductionDims);
 }
 
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -466,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() {
   SmallVector<bool> scalableDims;
   Type inferredReturnType;
   auto sourceScalableDims = getSourceVectorType().getScalableDims();
-  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
-        })) {
-      targetShape.push_back(it.value());
-      scalableDims.push_back(sourceScalableDims[it.index()]);
+  for (auto [dimIdx, dimSize] :
+       llvm::enumerate(getSourceVectorType().getShape()))
+    if (!llvm::any_of(getReductionDims(),
+                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
+                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
+                      })) {
+      targetShape.push_back(dimSize);
+      scalableDims.push_back(sourceScalableDims[dimIdx]);
     }
   // TODO: update to also allow 0-d vectors when available.
   if (targetShape.empty())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ac576ed0b4f097..716da55ba09aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
 
     // Separate reduction and parallel dims
-    auto reductionDimsRange =
-        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
-    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
-        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                   reductionDims.end());
    int64_t reductionSize = reductionDims.size();
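Note for downstream users (illustrative, not part of the patch): with `reduction_dims` stored as a `DenseI64ArrayAttr`, `getReductionDims()` now yields `ArrayRef<int64_t>` instead of an `ArrayAttr` of `IntegerAttr`s, so the attribute-unwrapping pattern (`getAsRange<IntegerAttr>()` / `getInt()`) becomes a plain iteration over `int64_t`. A minimal migration sketch, using a hypothetical helper over `vector::MultiDimReductionOp`:

```cpp
// Illustrative only: the helper below is hypothetical and mirrors the
// in-tree updates in this patch.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Collect the sizes of the reduced dimensions of a vector.multi_reduction.
static llvm::SmallVector<int64_t>
getReducedDimSizes(vector::MultiDimReductionOp op) {
  llvm::ArrayRef<int64_t> shape = op.getSourceVectorType().getShape();
  llvm::SmallVector<int64_t> sizes;
  // New accessor: iterate plain int64_t indices directly, no IntegerAttr
  // unwrapping needed anymore.
  for (int64_t dim : op.getReductionDims())
    sizes.push_back(shape[dim]);
  return sizes;
}
```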