Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Complex] Fix bug in MergeComplexBitcast #74271

Conversation

matthias-springer
Copy link
Member

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).

When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that convertes non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which will should intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (llvm#74270).
@llvmbot llvmbot added mlir mlir:complex MLIR complex dialect labels Dec 4, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir-complex

Author: Matthias Springer (matthias-springer)

Changes

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).


Full diff: https://github.com/llvm/llvm-project/pull/74271.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19)
  • (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).


Full diff: https://github.com/llvm/llvm-project/pull/74271.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19)
  • (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }

@matthias-springer
Copy link
Member Author

Should we fix "bugs" like this one? Is it actually bug? I think there is at the moment no requirement that the IR has to verify after each pattern application.

I was looking into this because I had to debug a pass that applies multiple patterns and I wanted to see how an op was getting simplified. So I was running with -debug. Then I saw invalid IR half way through the process and I thought "if the IR already broken at this point, I don't even have to look further".

@matthias-springer matthias-springer merged commit 192439d into llvm:main Dec 5, 2023
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:complex MLIR complex dialect mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants