Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Fix op dominance bug in rewrite pattern #74159

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 1, 2023

Fixes a bug in SplitDeallocWhenNotAliasingAnyOther. This pattern used to generate invalid IR (op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards and those rewrites "fixed up" the IR again. (The bug is visible when running mlir-opt -debug.) Also add additional comments to the implementation and simplify the code a bit.

Apart from the fixed dominance error, this change is NFC. Without this change, buffer deallocation tests will fail when running with #74270.

Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used to generate invalid IR (op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards and those rewrites "fixed up" the IR again. Also add additional comments to the implementation and simplify the code a bit.
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Dec 1, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Fixes a bug in SplitDeallocWhenNotAliasingAnyOther. This pattern used to generate invalid IR (op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards and those rewrites "fixed up" the IR again. Also add additional comments to the implementation and simplify the code a bit.

Apart from the fixed dominance error, this change is NFC.


Full diff: https://github.com/llvm/llvm-project/pull/74159.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+31-24)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 7bbdeab3ea1a870..42653517249d664 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
+    Location loc = deallocOp.getLoc();
     if (deallocOp.getMemrefs().size() <= 1)
       return failure();
 
-    SmallVector<Value> newMemrefs, newConditions, replacements;
-    DenseSet<Operation *> exceptedUsers;
-    replacements = deallocOp.getUpdatedConditions();
+    SmallVector<Value> remainingMemrefs, remainingConditions;
+    SmallVector<SmallVector<Value>> updatedConditions;
     for (auto [memref, cond] :
          llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Check if `memref` can split off into a separate bufferization.dealloc.
       if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
                                    memref, true)) {
-        newMemrefs.push_back(memref);
-        newConditions.push_back(cond);
+        // `memref` alias with other memrefs, do not split off.
+        remainingMemrefs.push_back(memref);
+        remainingConditions.push_back(cond);
         continue;
       }
 
-      auto newDeallocOp = rewriter.create<DeallocOp>(
-          deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
-      replacements = SmallVector<Value>(llvm::map_range(
-          llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
-          [&](auto replAndNew) -> Value {
-            auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
-                                                      std::get<0>(replAndNew),
-                                                      std::get<1>(replAndNew));
-            exceptedUsers.insert(orOp);
-            return orOp.getResult();
-          }));
+      // Create new bufferization.dealloc op for `memref`.
+      auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
+                                                     deallocOp.getRetained());
+      updatedConditions.push_back(
+          llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
     }
 
-    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+    // Fail if no memref was split off.
+    if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
       return failure();
 
-    rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements,
-                               [&](OpOperand &operand) {
-                                 return !exceptedUsers.contains(
-                                     operand.getOwner());
-                               });
-    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
-                                  rewriter);
+    // Create bufferization.dealloc op for all remaining memrefs.
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+
+    // Bit-or all conditions.
+    SmallVector<Value> replacements =
+        llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
+    for (auto additionalConditions : updatedConditions) {
+      assert(replacements.size() == additionalConditions.size() &&
+             "expected same number of updated conditions");
+      for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
+        replacements[i] = rewriter.create<arith::OrIOp>(
+            loc, replacements[i], additionalConditions[i]);
+      }
+    }
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
   }
 
 private:

Copy link
Member

@maerhart maerhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

@joker-eph
Copy link
Collaborator

We're missing a test don't we?

@matthias-springer
Copy link
Member Author

We're missing a test don't we?

I tried to write one but there was always some other pattern (or the folder), which cleaned up the IR again. Maybe it's not possible to trigger this bug without adding an option to the pass to only apply this one pattern.

I only noticed this issue because I was running with -debug. Maybe we could have a new MLIR flag that triggers the verifier after each pattern application? We could pass this flag in unit tests. (Similar to the test-convergence pass option of the canonicalizer, but this would probably have to be a mlir-opt option.)

@matthias-springer matthias-springer merged commit 3dae97c into llvm:main Dec 5, 2023
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants