diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3193a088e215..ab97f02bb66b 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2606,21 +2606,44 @@ static bool isNot(Value *a, Value *b) { struct compare_insts { public: + DominatorTree &DT; + LoopInfo &LI; + compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {} + // return true if B appears later than A. bool operator()(Instruction * A, Instruction *B) const { - if (A->getParent() != B->getParent()) - return A < B; - return A->comesBefore(B); + if (A->getParent() == B->getParent()) { + return !A->comesBefore(B); + } + auto AB = A->getParent(); + auto BB = B->getParent(); + assert(AB->getParent() == BB->getParent()); + if (DT.dominates(AB, BB)) + return false; + if (!DT.dominates(BB, AB)) + return true; + for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) { + if (prev == AB) + return true; + } + return false; } }; class DominatorOrderSet : public std::set { public: - DominatorOrderSet() : std::set(compare_insts{}) {} - DominatorOrderSet(std::set::iterator begin, std::set::iterator end) : std::set(begin, end, compare_insts{}) {} + DominatorOrderSet(DominatorTree &DT, LoopInfo &LI) : std::set(compare_insts(DT, LI)) {} + bool contains(Instruction* I) const { return count(I) != 0; } + void remove(Instruction* I) { erase(I); } + Instruction* pop_back_val() { + auto back = end(); + back--; + auto v = *back; + erase(back); + return v; + } }; -typedef llvm::SetVector, - DominatorOrderSet> QueueType; +typedef DominatorOrderSet QueueType; std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, QueueType &Q, @@ -2686,6 +2709,12 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, operands.insert(I2); } } + if (Q.contains(I)) { + Q.remove(I); + } + for (auto q : Q) + llvm::errs() << " -- q: " << *q << "\n"; + llvm::errs() << "erasing I: " << *I << "\n"; I->eraseFromParent(); for (auto op : operands) if (op->getNumUses() == 0) { @@ -2700,6 +2729,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (cur->getNumUses() == 0) { for (size_t i = 0; i < cur->getNumOperands(); i++) push(cur->getOperand(i)); + assert(!Q.contains(cur)); cur->eraseFromParent(); return "DCE"; } @@ -4142,7 +4172,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, Value *b = cur->getOperand(1 - i); // fmul (fmul x:constant, y):z, b:constant . - if (auto C = dyn_cast(b)) + if (isa(b)) if (auto z = dyn_cast(prelhs)) { if (z->getOpcode() == Instruction::FMul) { for (int j = 0; j < 2; j++) { @@ -5509,7 +5539,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, auto &LI = FAM.getResult(F); auto &DL = F.getParent()->getDataLayout(); - QueueType Q; + QueueType Q(DT, LI); { llvm::SetVector todoBlocks; for (auto b : toDenseBlocks) { @@ -5530,7 +5560,8 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); - QueueType prev(Q.begin(), Q.end()); + std::set prev; + for (auto v : Q) prev.insert(v); llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; @@ -5538,7 +5569,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, llvm::errs() << "changed: " << *changed << "\n"; for (auto I : Q) - if (!prev.contains(I)) + if (!prev.count(I)) llvm::errs() << " + " << *I << "\n"; llvm::errs() << F << "\n\n"; } diff --git a/enzyme/test/Integration/Sparse/sqrtspring.cpp b/enzyme/test/Integration/Sparse/sqrtspring.cpp index fdec1f99cdde..40d244704e1f 100644 --- a/enzyme/test/Integration/Sparse/sqrtspring.cpp +++ b/enzyme/test/Integration/Sparse/sqrtspring.cpp @@ -7,6 +7,8 @@ // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +#include +#include #include #include #include