Skip to content

Commit

Permalink
[mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159)
Browse files Browse the repository at this point in the history
Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used
to generate invalid IR (op dominance error). We never noticed this bug
in existing test cases because other patterns and/or foldings were
applied afterwards and those rewrites "fixed up" the IR again. (The bug
is visible when running `mlir-opt -debug`.) Also add additional comments
to the implementation and simplify the code a bit.

Apart from the fixed dominance error, this change is NFC. Without this
change, buffer deallocation tests will fail when running with #74270.
  • Loading branch information
matthias-springer authored Dec 5, 2023
1 parent 4288fb8 commit 3dae97c
Showing 1 changed file with 31 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther

LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
Location loc = deallocOp.getLoc();
if (deallocOp.getMemrefs().size() <= 1)
return failure();

SmallVector<Value> newMemrefs, newConditions, replacements;
DenseSet<Operation *> exceptedUsers;
replacements = deallocOp.getUpdatedConditions();
SmallVector<Value> remainingMemrefs, remainingConditions;
SmallVector<SmallVector<Value>> updatedConditions;
for (auto [memref, cond] :
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
// Check if `memref` can split off into a separate bufferization.dealloc.
if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
memref, true)) {
newMemrefs.push_back(memref);
newConditions.push_back(cond);
// `memref` alias with other memrefs, do not split off.
remainingMemrefs.push_back(memref);
remainingConditions.push_back(cond);
continue;
}

auto newDeallocOp = rewriter.create<DeallocOp>(
deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
replacements = SmallVector<Value>(llvm::map_range(
llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
[&](auto replAndNew) -> Value {
auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
std::get<0>(replAndNew),
std::get<1>(replAndNew));
exceptedUsers.insert(orOp);
return orOp.getResult();
}));
// Create new bufferization.dealloc op for `memref`.
auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
deallocOp.getRetained());
updatedConditions.push_back(
llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
}

if (newMemrefs.size() == deallocOp.getMemrefs().size())
// Fail if no memref was split off.
if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
return failure();

rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
[&](OpOperand &operand) {
return !exceptedUsers.contains(
operand.getOwner());
});
return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
rewriter);
// Create bufferization.dealloc op for all remaining memrefs.
auto newDeallocOp = rewriter.create<DeallocOp>(
loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());

// Bit-or all conditions.
SmallVector<Value> replacements =
llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
for (auto additionalConditions : updatedConditions) {
assert(replacements.size() == additionalConditions.size() &&
"expected same number of updated conditions");
for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
replacements[i] = rewriter.create<arith::OrIOp>(
loc, replacements[i], additionalConditions[i]);
}
}
rewriter.replaceOp(deallocOp, replacements);
return success();
}

private:
Expand Down

0 comments on commit 3dae97c

Please sign in to comment.