Skip to content

Commit

Permalink
fixed memory issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 8, 2024
1 parent 75d9bc0 commit e940f5a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
53 changes: 42 additions & 11 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2606,21 +2606,44 @@ static bool isNot(Value *a, Value *b) {

// Ordering functor over Instruction pointers; DominatorOrderSet uses it as the
// comparator of its underlying std::set.
//
// NOTE(review): this excerpt looks like a diff rendering that shows removed and
// added lines together without +/- markers. Confirm against the real file
// which statements inside operator() are actually live.
struct compare_insts {
public:
// Dominator information consulted when A and B are in different blocks.
DominatorTree &DT;
// Stored by the constructor; not read by operator() as shown here.
LoopInfo &LI;
// Binds the two analyses by reference; no copy is made here.
// NOTE(review): presumably both must outlive this comparator - confirm at the
// construction site.
compare_insts(DominatorTree &DT, LoopInfo &LI) : DT(DT), LI(LI) {}
// return true if B appears later than A.
// NOTE(review): the sentence above matches the `A->comesBefore(B)` return
// below. The `!A->comesBefore(B)` return further down answers the opposite
// question (true when A does not come before B).
bool operator()(Instruction * A, Instruction *B) const {
// Different parent blocks: order by raw pointer value.
// NOTE(review): together with the next return, these statements return on
// every path, so everything after them is unreachable if they are really
// present in the file (they are presumably the removed side of the diff).
if (A->getParent() != B->getParent())
return A < B;
return A->comesBefore(B);
// Same parent block: true when A does not come before B.
// NOTE(review): for A == B this presumably yields true (assuming
// comesBefore(A, A) is false). A std::set comparator must return false for
// comp(x, x); otherwise lookups and erases by key can miss an element that
// is present - verify.
if (A->getParent() == B->getParent()) {
return !A->comesBefore(B);
}
auto AB = A->getParent();
auto BB = B->getParent();
// Both blocks are expected to belong to the same function.
assert(AB->getParent() == BB->getParent());
// A's block dominates B's block: report false.
if (DT.dominates(AB, BB))
return false;
// B's block does not dominate A's block (neither dominates the other at this
// point): report true.
if (!DT.dominates(BB, AB))
return true;
// Remaining case: B's block dominates A's block. Walk backwards from B's
// block through getPrevNode() and report true if A's block is found among
// the blocks that precede it.
for (auto prev = BB->getPrevNode(); prev; prev = prev->getPrevNode()) {
if (prev == AB)
return true;
}
return false;
}
};

// std::set of Instruction pointers ordered by compare_insts, extended with the
// small helper API (contains / remove / pop_back_val) used by the code that
// treats it as a work queue.
class DominatorOrderSet : public std::set<Instruction*, compare_insts> {
public:
// NOTE(review): the next two constructors build `compare_insts{}` with no
// arguments, but compare_insts as shown only declares a
// (DominatorTree &, LoopInfo &) constructor and holds reference members, so it
// has no default constructor. These lines are presumably the removed side of
// the diff - confirm against the real file.
DominatorOrderSet() : std::set<Instruction*, compare_insts>(compare_insts{}) {}
DominatorOrderSet(std::set<Instruction*, compare_insts>::iterator begin, std::set<Instruction*, compare_insts>::iterator end) : std::set<Instruction*, compare_insts>(begin, end, compare_insts{}) {}
// Creates an empty set whose comparator is constructed from DT and LI.
DominatorOrderSet(DominatorTree &DT, LoopInfo &LI) : std::set<Instruction*, compare_insts>(compare_insts(DT, LI)) {}
// True when count(I) is non-zero, i.e. the set finds an element equivalent to
// I under the comparator.
bool contains(Instruction* I) const { return count(I) != 0; }
// Erases I by key; the number of erased elements is discarded.
void remove(Instruction* I) { erase(I); }
// Removes and returns the last element in comparator order.
// Precondition: the set is not empty. Decrementing end() on an empty set is
// undefined behavior. The `while (!Q.empty())` loop visible in
// fixSparseIndices checks this before calling.
Instruction* pop_back_val() {
auto back = end();
back--;
auto v = *back;
erase(back);
return v;
}
};

// Work-queue type passed to fixSparse_inner and instantiated in
// fixSparseIndices.
// NOTE(review): two typedefs of QueueType to different types are shown here.
// Both cannot be live in one translation unit. The SetVector form is presumably
// the removed side of the diff and the plain DominatorOrderSet alias the added
// side - confirm against the real file.
typedef llvm::SetVector<Instruction *, llvm::SmallVector<Instruction *, 1>,
DominatorOrderSet> QueueType;
typedef DominatorOrderSet QueueType;

std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
QueueType &Q,
Expand Down Expand Up @@ -2686,6 +2709,12 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
operands.insert(I2);
}
}
if (Q.contains(I)) {
Q.remove(I);
}
for (auto q : Q)
llvm::errs() << " -- q: " << *q << "\n";
llvm::errs() << "erasing I: " << *I << "\n";
I->eraseFromParent();
for (auto op : operands)
if (op->getNumUses() == 0) {
Expand All @@ -2700,6 +2729,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
if (cur->getNumUses() == 0) {
for (size_t i = 0; i < cur->getNumOperands(); i++)
push(cur->getOperand(i));
assert(!Q.contains(cur));
cur->eraseFromParent();
return "DCE";
}
Expand Down Expand Up @@ -4142,7 +4172,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
Value *b = cur->getOperand(1 - i);

// fmul (fmul x:constant, y):z, b:constant .
if (auto C = dyn_cast<Constant>(b))
if (isa<Constant>(b))
if (auto z = dyn_cast<BinaryOperator>(prelhs)) {
if (z->getOpcode() == Instruction::FMul) {
for (int j = 0; j < 2; j++) {
Expand Down Expand Up @@ -5509,7 +5539,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
auto &LI = FAM.getResult<LoopAnalysis>(F);
auto &DL = F.getParent()->getDataLayout();

QueueType Q;
QueueType Q(DT, LI);
{
llvm::SetVector<BasicBlock *> todoBlocks;
for (auto b : toDenseBlocks) {
Expand All @@ -5530,15 +5560,16 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
// Full simplification
while (!Q.empty()) {
auto cur = Q.pop_back_val();
QueueType prev(Q.begin(), Q.end());
std::set<Instruction*> prev;
for (auto v : Q) prev.insert(v);
llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n";
auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
(void)changed;
if (changed) {
llvm::errs() << "changed: " << *changed << "\n";

for (auto I : Q)
if (!prev.contains(I))
if (!prev.count(I))
llvm::errs() << " + " << *I << "\n";
llvm::errs() << F << "\n\n";
}
Expand Down
2 changes: 2 additions & 0 deletions enzyme/test/Integration/Sparse/sqrtspring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <assert.h>
#include <vector>
Expand Down

0 comments on commit e940f5a

Please sign in to comment.