diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp index 5d81e9b99814ed..90ff8ef3b497fe 100644 --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/Threading.h" #include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/PrettyStackTrace.h" @@ -55,6 +56,7 @@ class OperationVerifier { private: using WorkItem = llvm::PointerUnion; + using WorkItemEntry = llvm::PointerIntPair; /// This verifier uses a DFS of the tree of operations/blocks. The method /// verifyOnEntrance is invoked when we visit a node for the first time, i.e. @@ -267,10 +269,9 @@ LogicalResult OperationVerifier::verifyOnExit(Operation &op) { /// Such ops are collected separately and verified inside /// verifyBlockPostChildren. LogicalResult OperationVerifier::verifyOperation(Operation &op) { - SmallVector worklist{{&op}}; - SmallPtrSet seen; + SmallVector worklist{{&op, false}}; while (!worklist.empty()) { - WorkItem top = worklist.back(); + WorkItemEntry &top = worklist.back(); auto visit = [](auto &&visitor, WorkItem w) { if (w.is()) @@ -278,24 +279,28 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) { return visitor(w.get()); }; - const bool isExit = !seen.insert(top).second; + const bool isExit = top.getInt(); + top.setInt(true); + auto item = top.getPointer(); + // 2nd visit of this work item ("exit"). if (isExit) { - worklist.pop_back(); - if (failed(visit( - [this](auto *workItem) { return verifyOnExit(*workItem); }, top))) + if (failed( + visit([this](auto *workItem) { return verifyOnExit(*workItem); }, + item))) return failure(); + worklist.pop_back(); continue; } // 1st visit of this work item ("entrance"). if (failed(visit( [this](auto *workItem) { return verifyOnEntrance(*workItem); }, - top))) + item))) return failure(); - if (top.is()) { - Block ¤tBlock = *top.get(); + if (item.is()) { + Block ¤tBlock = *item.get(); // Skip "isolated from above operations". for (Operation &o : llvm::reverse(currentBlock)) { if (o.getNumRegions() == 0 || @@ -305,7 +310,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) { continue; } - Operation ¤tOp = *top.get(); + Operation ¤tOp = *item.get(); if (verifyRecursively) for (Region ®ion : llvm::reverse(currentOp.getRegions())) for (Block &block : llvm::reverse(region))