From b7718647049d4534ab9425a93565bc0fb8ff5d48 Mon Sep 17 00:00:00 2001 From: Jesse Michel Date: Thu, 2 Nov 2023 21:32:55 -0400 Subject: [PATCH 1/4] Ring of springs integration test using modulo --- enzyme/test/Integration/Sparse/ringspring.cpp | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 enzyme/test/Integration/Sparse/ringspring.cpp diff --git a/enzyme/test/Integration/Sparse/ringspring.cpp b/enzyme/test/Integration/Sparse/ringspring.cpp new file mode 100644 index 000000000000..ad54ec94228f --- /dev/null +++ b/enzyme/test/Integration/Sparse/ringspring.cpp @@ -0,0 +1,116 @@ +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi + +#include +#include +#include + +#include + +struct triple { + size_t row; + size_t col; + double val; + triple(triple&&) = default; + triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} +}; + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern void __enzyme_autodiff(void *, ...); + +extern void __enzyme_fwddiff(void *, ...); + +extern double* __enzyme_todense(void *, ...) noexcept; + + +/// Compute energy +double f(size_t N, double* input) { + double out = 0; + __builtin_assume(!((N-1) == 0)); + for (size_t i=0; i &triplets) { + printf("row=%d col=%d val=%f\n", row, col, val); + assert(abs(val) > 0.00001); + triplets.emplace_back(row, col, val); +} + +__attribute__((always_inline)) +void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { + if (val == 0.0) return; + idx /= sizeof(double); + inner_store(i, idx, val, triplets); +} + +__attribute__((always_inline)) +double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplets) { + return 0.0; +} + +__attribute__((noinline)) +std::vector hess_f(size_t N, double* input) { + std::vector triplets; + __builtin_assume(N > 0); + for (size_t i=0; i Date: Tue, 7 Nov 2023 16:32:39 -0500 Subject: [PATCH 2/4] new bugs --- enzyme/test/Integration/Sparse/ringspring.cpp | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/enzyme/test/Integration/Sparse/ringspring.cpp b/enzyme/test/Integration/Sparse/ringspring.cpp index ad54ec94228f..7c020a23e7b5 100644 --- a/enzyme/test/Integration/Sparse/ringspring.cpp +++ b/enzyme/test/Integration/Sparse/ringspring.cpp @@ -7,10 +7,14 @@ // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi +// everything should be always inline + + #include #include #include + #include struct triple { @@ -21,6 +25,9 @@ struct triple { triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {} }; + +size_t N = 8; + extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -36,12 +43,12 @@ extern double* __enzyme_todense(void *, ...) noexcept; /// Compute energy double f(size_t N, double* input) { double out = 0; - __builtin_assume(!((N-1) == 0)); + // __builtin_assume(!((N-1) == 0)); for (size_t i=0; i &triplets) { - printf("row=%d col=%d val=%f\n", row, col, val); - assert(abs(val) > 0.00001); - triplets.emplace_back(row, col, val); + printf("row=%d col=%d val=%f\n", row, col % N, val); + // assert(abs(val) > 0.00001); + triplets.emplace_back(row % N, col % N, val); } __attribute__((always_inline)) @@ -81,9 +88,21 @@ double sparse_load(int64_t idx, size_t i, size_t N, std::vector &triplet return 0.0; } +__attribute__((always_inline)) +void never_store(double val, int64_t idx, double* input, size_t N) { + assert(0 && "this is a read only input, why are you storing here..."); +} + +__attribute__((always_inline)) +double mod_load(int64_t idx, double* input, size_t N) { + idx /= sizeof(double); + return input[idx % N]; +} + __attribute__((noinline)) std::vector hess_f(size_t N, double* input) { std::vector triplets; + input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); for (size_t i=0; i hess_f(size_t N, double* input) { return triplets; } + int main() { - size_t N = 8; + // size_t N = 8; double x[N]; for (int i=0; i Date: Tue, 7 Nov 2023 23:54:20 -0500 Subject: [PATCH 3/4] add phi fix --- enzyme/Enzyme/FunctionUtils.cpp | 92 +++++++++++++++++---------------- enzyme/Enzyme/GradientUtils.cpp | 49 +++++++++++++++--- 2 files changed, 89 insertions(+), 52 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 5abc0a3729f3..703e0d833a76 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -4329,56 +4329,58 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } // phi (idx=0) ? b, a, a -> select (idx == 0), b, a if (auto L = LI.getLoopFor(PN->getParent())) - if (auto idx = L->getCanonicalInductionVariable()) - if (auto PH = L->getLoopPreheader()) { - bool legal = idx != PN; - auto ph_idx = PN->getBasicBlockIndex(PH); - for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { - if ((int)i == ph_idx) - continue; - auto v = PN->getIncomingValue(i); - if (v != PN->getIncomingValue(1 - ph_idx)) { - legal = false; - break; - } - // The given var must dominate the loop - if (isa(v)) - continue; - if (isa(v)) - continue; - // exception for the induction itself, which we handle specially - if (v == idx) - continue; - auto I = cast(v); - if (!DT.dominates(I, PN)) { - legal = false; - break; - } - } - if (legal) { - auto val = PN->getIncomingValue(1 - ph_idx); - push(val); - if (val == idx) { - val = pushcse( - B.CreateSub(idx, ConstantInt::get(idx->getType(), 1))); + if (L->getHeader() == PN->getParent()) + if (auto idx = L->getCanonicalInductionVariable()) + if (auto PH = L->getLoopPreheader()) { + bool legal = idx != PN; + auto ph_idx = PN->getBasicBlockIndex(PH); + assert(ph_idx >= 0); + for (size_t i = 0; i < PN->getNumIncomingValues(); i++) { + if ((int)i == ph_idx) + continue; + auto v = PN->getIncomingValue(i); + if (v != PN->getIncomingValue(1 - ph_idx)) { + legal = false; + break; + } + // The given var must dominate the loop + if (isa(v)) + continue; + if (isa(v)) + continue; + // exception for the induction itself, which we handle specially + if (v == idx) + continue; + auto I = cast(v); + if (!DT.dominates(I, PN)) { + legal = false; + break; + } } + if (legal) { + auto val = PN->getIncomingValue(1 - ph_idx); + push(val); + if (val == idx) { + val = pushcse( + B.CreateSub(idx, ConstantInt::get(idx->getType(), 1))); + } - auto val2 = PN->getIncomingValue(ph_idx); - push(val2); + auto val2 = PN->getIncomingValue(ph_idx); + push(val2); - auto c0 = ConstantInt::get(idx->getType(), 0); - // if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) { - // val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val); - //} else { - auto eq = pushcse(B.CreateICmpEQ(idx, c0)); - val = pushcse( - B.CreateSelect(eq, val2, val, "phisel." + cur->getName())); - //} + auto c0 = ConstantInt::get(idx->getType(), 0); + // if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) { + // val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val); + //} else { + auto eq = pushcse(B.CreateICmpEQ(idx, c0)); + val = pushcse( + B.CreateSelect(eq, val2, val, "phisel." + cur->getName())); + //} - replaceAndErase(cur, val); - return "PhiLoop0Sel"; + replaceAndErase(cur, val); + return "PhiLoop0Sel"; + } } - } // phi (sitofp a), (sitofp b) -> sitofp (phi a, b) { SmallVector negOps; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index ca5996eca9ce..d52a5f80f577 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9158,14 +9158,15 @@ bool GradientUtils::needsCacheWholeAllocation( return false; if (!found->second) return true; - SmallVector, 1> todo; + // User, operand of input, whehter the input is the original allocation + SmallVector, 1> todo; for (auto &use : origInst->uses()) - todo.push_back( - std::make_pair(cast(use.getUser()), use.getOperandNo())); - SmallSet, 1> seen; + todo.push_back(std::make_tuple(cast(use.getUser()), + use.getOperandNo(), true)); + SmallSet, 1> seen; while (todo.size()) { auto pair = todo.back(); - auto [cur, idx] = pair; + auto [cur, idx, orig] = pair; todo.pop_back(); if (seen.count(pair)) continue; @@ -9184,6 +9185,8 @@ bool GradientUtils::needsCacheWholeAllocation( II->getIntrinsicID() == Intrinsic::masked_load) continue; + bool returnedSameValue = false; + if (auto CI = dyn_cast(cur)) { #if LLVM_VERSION_MAJOR >= 14 if (idx < CI->arg_size()) @@ -9193,6 +9196,36 @@ bool GradientUtils::needsCacheWholeAllocation( { if (isNoCapture(CI, idx)) continue; + + if (auto F = CI->getCalledFunction()) + if (F->getCallingConv() == CI->getCallingConv()) { + bool onlyReturnUses = true; + bool hasReturnUse = true; + + for (auto u : F->getArg(idx)->users()) { + if (isa(u)) { + hasReturnUse = true; + continue; + } + onlyReturnUses = false; + continue; + } + // The arg itself has no use in the function + if (onlyReturnUses && !hasReturnUse) + continue; + + // If this is the original allocation, we return it guaranteed, and + // cache the return, that's still fine + if (onlyReturnUses && orig) { + found = knownRecomputeHeuristic.find(cur); + if (found == knownRecomputeHeuristic.end()) + continue; + + if (!found->second) + continue; + returnedSameValue = true; + } + } } } @@ -9202,6 +9235,7 @@ bool GradientUtils::needsCacheWholeAllocation( // If caching this user, it cannot be a gep/cast of original if (!found->second) { + llvm::errs() << " mod: " << *oldFunc->getParent() << "\n"; llvm::errs() << " oldFunc: " << *oldFunc << "\n"; for (auto &pair : knownRecomputeHeuristic) llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n"; @@ -9211,8 +9245,9 @@ bool GradientUtils::needsCacheWholeAllocation( } else { // if not caching this user, it is legal to recompute, consider its users for (auto &use : cur->uses()) { - todo.push_back(std::make_pair(cast(use.getUser()), - use.getOperandNo())); + todo.push_back(std::make_tuple(cast(use.getUser()), + use.getOperandNo(), + returnedSameValue && orig)); } } } From c91e52400679a6604290f436d739d0aa4ef7d5ab Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 7 Nov 2023 23:59:15 -0500 Subject: [PATCH 4/4] Additional phi fix --- enzyme/Enzyme/FunctionUtils.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 703e0d833a76..f4077d8e3fde 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1047,6 +1047,7 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) { B.SetInsertPoint(res); if (auto PT = dyn_cast(storePointer->getType())) { + (void)PT; #if LLVM_VERSION_MAJOR < 17 #if LLVM_VERSION_MAJOR >= 15 if (PT->getContext().supportsTypedPointers()) { @@ -4659,7 +4660,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, legal = false; auto L = LI.getLoopFor(PN->getParent()); if (legal && L && L->getLoopPreheader() && - L->getCanonicalInductionVariable()) { + L->getCanonicalInductionVariable() && + L->getHeader() == PN->getParent()) { auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader()); if (isa(PN->getIncomingValue(ph_idx))) { lhsOps[ph_idx] =