Skip to content

Commit

Permalink
[mlir][vector] Use DenseI64ArrayAttr in vector.multi_reduction (#10…
Browse files Browse the repository at this point in the history
…2637)

This prevents some unnecessary conversions to/from int64_t and
IntegerAttr.
  • Loading branch information
MacDue authored Aug 10, 2024
1 parent 2849ebb commit 5f26497
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
AnyType:$acc,
I64ArrayAttr:$reduction_dims)>,
DenseI64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
Expand Down Expand Up @@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :

SmallVector<bool> getReductionMask() {
SmallVector<bool> res(getSourceVectorType().getRank(), false);
for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
res[ia.getInt()] = true;
for (int64_t dim : getReductionDims())
res[dim] = true;
return res;
}
static SmallVector<bool> getReductionMask(
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
build(builder, result, kind, source, acc,
builder.getI64ArrayAttr(reductionDims));
build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
Expand All @@ -466,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() {
SmallVector<bool> scalableDims;
Type inferredReturnType;
auto sourceScalableDims = getSourceVectorType().getScalableDims();
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
})) {
targetShape.push_back(it.value());
scalableDims.push_back(sourceScalableDims[it.index()]);
for (auto [dimIdx, dimSize] :
llvm::enumerate(getSourceVectorType().getShape()))
if (!llvm::any_of(getReductionDims(),
[dimIdx = dimIdx](int64_t reductionDimIdx) {
return reductionDimIdx == static_cast<int64_t>(dimIdx);
})) {
targetShape.push_back(dimSize);
scalableDims.push_back(sourceScalableDims[dimIdx]);
}
// TODO: update to also allow 0-d vectors when available.
if (targetShape.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
auto srcRank = multiReductionOp.getSourceVectorType().getRank();

// Separate reduction and parallel dims
auto reductionDimsRange =
multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
reductionDims.end());
int64_t reductionSize = reductionDims.size();
Expand Down

0 comments on commit 5f26497

Please sign in to comment.