diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 7decbce018a878..eca13f52f53dc4 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -15,6 +15,7 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/FoldUtils.h" @@ -432,6 +433,10 @@ bool GreedyPatternRewriteDriver::processWorklist() { if (succeeded(folder.tryToFold(op))) { LLVM_DEBUG(logResultWithLine("success", "operation was folded")); changed = true; +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope && failed(verify(config.scope->getParentOp()))) + llvm::report_fatal_error("IR failed to verify after folding"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS continue; } @@ -464,8 +469,9 @@ bool GreedyPatternRewriteDriver::processWorklist() { #endif #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - debugFingerPrints.computeFingerPrints( - /*topLevel=*/config.scope ? config.scope->getParentOp() : op); + if (config.scope) { + debugFingerPrints.computeFingerPrints(config.scope->getParentOp()); + } auto clearFingerprints = llvm::make_scope_exit([&]() { debugFingerPrints.clear(); }); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS @@ -473,17 +479,24 @@ bool GreedyPatternRewriteDriver::processWorklist() { LogicalResult matchResult = matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope && failed(verify(config.scope->getParentOp()))) + llvm::report_fatal_error("IR failed to verify after pattern application"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (succeeded(matchResult)) { LLVM_DEBUG(logResultWithLine("success", "pattern matched")); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - debugFingerPrints.notifyRewriteSuccess(); + if (config.scope) + debugFingerPrints.notifyRewriteSuccess(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS changed = true; ++numRewrites; } else { LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - debugFingerPrints.notifyRewriteFailure(); + if (config.scope) + debugFingerPrints.notifyRewriteFailure(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } } @@ -562,6 +575,18 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + +#ifndef NDEBUG + // Only ops that are within the configured scope are added to the worklist of + // the greedy pattern rewriter. Moreover, the parent op of the scope region is + // the part of the IR that is taken into account for the "expensive checks". + // A greedy pattern rewrite is not allowed to erase the parent op of the scope + // region, as that would break the worklist handling and the expensive checks. + if (config.scope && config.scope->getParentOp() == op) + llvm_unreachable( + "scope region must not be erased during greedy pattern rewrite"); +#endif // NDEBUG + if (config.listener) config.listener->notifyOperationRemoved(op); @@ -721,6 +746,12 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion, if (!config.scope) config.scope = ®ion; +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (failed(verify(config.scope->getParentOp()))) + llvm::report_fatal_error( + "greedy pattern rewriter input IR failed to verify"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Start the pattern driver. RegionPatternRewriteDriver driver(region.getContext(), patterns, config, region); @@ -846,6 +877,12 @@ LogicalResult mlir::applyOpPatternsAndFold( #endif // NDEBUG } +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (config.scope && failed(verify(config.scope->getParentOp()))) + llvm::report_fatal_error( + "greedy pattern rewriter input IR failed to verify"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Start the pattern driver. llvm::SmallDenseSet surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,