Skip to content

Commit

Permalink
[mlir][Linalg] Avoid doing op replacement in linalg::dropUnitDims. (l…
Browse files Browse the repository at this point in the history
…lvm#105749)

It is better to do the replacement in the caller. This avoids the
footgun if the caller needs the original operation. Instead return the
produced operation and replacement values.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar authored Aug 23, 2024
1 parent a2a5508 commit 4dbaef6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,13 @@ struct ControlDropUnitDims {
return SmallVector<unsigned>{};
};
};
LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
const ControlDropUnitDims &options);
struct DropUnitDimsResult {
linalg::GenericOp resultOp;
SmallVector<Value> replacements;
};
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);

/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
Expand Down
16 changes: 11 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,9 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
return info;
}

LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
const ControlDropUnitDims &options) {
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
const ControlDropUnitDims &options) {
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
Expand Down Expand Up @@ -545,8 +546,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
resultReplacements.push_back(expandedValue);
}

rewriter.replaceOp(genericOp, resultReplacements);
return success();
return DropUnitDimsResult{replacementOp, resultReplacements};
}

namespace {
Expand All @@ -557,7 +557,13 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
return dropUnitDims(rewriter, genericOp, options);
FailureOr<DropUnitDimsResult> result =
dropUnitDims(rewriter, genericOp, options);
if (failed(result)) {
return failure();
}
rewriter.replaceOp(genericOp, result->replacements);
return success();
}

private:
Expand Down
8 changes: 7 additions & 1 deletion mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
linalg::GenericOp genericOp) {
linalg::ControlDropUnitDims options;
options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
return linalg::dropUnitDims(rewriter, genericOp, options);
FailureOr<linalg::DropUnitDimsResult> result =
linalg::dropUnitDims(rewriter, genericOp, options);
if (failed(result)) {
return failure();
}
rewriter.replaceOp(genericOp, result->replacements);
return success();
}

struct TestLinalgDropUnitDims
Expand Down

0 comments on commit 4dbaef6

Please sign in to comment.