Skip to content

Commit

Permalink
Ring of springs integration test using modulo (#1521)
Browse files Browse the repository at this point in the history
* Ring of springs integration test using modulo

* new bugs

* add phi fix

* Additional phi fix

---------

Co-authored-by: William S. Moses <gh@wsmoses.com>
  • Loading branch information
martinjm97 and wsmoses authored Nov 8, 2023
1 parent 35614bf commit 82318b8
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 53 deletions.
96 changes: 50 additions & 46 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
B.SetInsertPoint(res);

if (auto PT = dyn_cast<PointerType>(storePointer->getType())) {
(void)PT;
#if LLVM_VERSION_MAJOR < 17
#if LLVM_VERSION_MAJOR >= 15
if (PT->getContext().supportsTypedPointers()) {
Expand Down Expand Up @@ -4329,56 +4330,58 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
// phi (idx=0) ? b, a, a -> select (idx == 0), b, a
if (auto L = LI.getLoopFor(PN->getParent()))
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
if (L->getHeader() == PN->getParent())
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
assert(ph_idx >= 0);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
}

auto val2 = PN->getIncomingValue(ph_idx);
push(val2);
auto val2 = PN->getIncomingValue(ph_idx);
push(val2);

auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}
auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}

replaceAndErase(cur, val);
return "PhiLoop0Sel";
replaceAndErase(cur, val);
return "PhiLoop0Sel";
}
}
}
// phi (sitofp a), (sitofp b) -> sitofp (phi a, b)
{
SmallVector<Value *, 1> negOps;
Expand Down Expand Up @@ -4657,7 +4660,8 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
legal = false;
auto L = LI.getLoopFor(PN->getParent());
if (legal && L && L->getLoopPreheader() &&
L->getCanonicalInductionVariable()) {
L->getCanonicalInductionVariable() &&
L->getHeader() == PN->getParent()) {
auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader());
if (isa<ConstantInt>(PN->getIncomingValue(ph_idx))) {
lhsOps[ph_idx] =
Expand Down
49 changes: 42 additions & 7 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9158,14 +9158,15 @@ bool GradientUtils::needsCacheWholeAllocation(
return false;
if (!found->second)
return true;
SmallVector<std::pair<const Instruction *, size_t>, 1> todo;
// User, operand index of the input, whether the input is the original allocation
SmallVector<std::tuple<const Instruction *, size_t, bool>, 1> todo;
for (auto &use : origInst->uses())
todo.push_back(
std::make_pair(cast<Instruction>(use.getUser()), use.getOperandNo()));
SmallSet<std::pair<const Instruction *, size_t>, 1> seen;
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(), true));
SmallSet<std::tuple<const Instruction *, size_t, bool>, 1> seen;
while (todo.size()) {
auto pair = todo.back();
auto [cur, idx] = pair;
auto [cur, idx, orig] = pair;
todo.pop_back();
if (seen.count(pair))
continue;
Expand All @@ -9184,6 +9185,8 @@ bool GradientUtils::needsCacheWholeAllocation(
II->getIntrinsicID() == Intrinsic::masked_load)
continue;

bool returnedSameValue = false;

if (auto CI = dyn_cast<CallInst>(cur)) {
#if LLVM_VERSION_MAJOR >= 14
if (idx < CI->arg_size())
Expand All @@ -9193,6 +9196,36 @@ bool GradientUtils::needsCacheWholeAllocation(
{
if (isNoCapture(CI, idx))
continue;

if (auto F = CI->getCalledFunction())
if (F->getCallingConv() == CI->getCallingConv()) {
bool onlyReturnUses = true;
bool hasReturnUse = true;

for (auto u : F->getArg(idx)->users()) {
if (isa<ReturnInst>(u)) {
hasReturnUse = true;
continue;
}
onlyReturnUses = false;
continue;
}
// The arg itself has no use in the function
if (onlyReturnUses && !hasReturnUse)
continue;

// If this is the original allocation, we return it guaranteed, and
// cache the return, that's still fine
if (onlyReturnUses && orig) {
found = knownRecomputeHeuristic.find(cur);
if (found == knownRecomputeHeuristic.end())
continue;

if (!found->second)
continue;
returnedSameValue = true;
}
}
}
}

Expand All @@ -9202,6 +9235,7 @@ bool GradientUtils::needsCacheWholeAllocation(

// If caching this user, it cannot be a gep/cast of original
if (!found->second) {
llvm::errs() << " mod: " << *oldFunc->getParent() << "\n";
llvm::errs() << " oldFunc: " << *oldFunc << "\n";
for (auto &pair : knownRecomputeHeuristic)
llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n";
Expand All @@ -9211,8 +9245,9 @@ bool GradientUtils::needsCacheWholeAllocation(
} else {
// if not caching this user, it is legal to recompute, consider its users
for (auto &use : cur->uses()) {
todo.push_back(std::make_pair(cast<Instruction>(use.getUser()),
use.getOperandNo()));
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(),
returnedSameValue && orig));
}
}
}
Expand Down
136 changes: 136 additions & 0 deletions enzyme/test/Integration/Sparse/ringspring.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi

// Every callback helper below should be marked always_inline so it can be analyzed at its use site.


#include <stdio.h>
#include <assert.h>
#include <vector>


#include<math.h>

// One nonzero entry of a sparse matrix in coordinate (row, col, value) form;
// hess_f accumulates the Hessian as a list of these.
struct triple {
    size_t row;
    size_t col;
    double val;
    // Declaring the move constructor suppresses the implicit copy
    // constructor, so triples can only be moved (e.g. via emplace_back),
    // never copied.
    triple(triple&&) = default;
    triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {}
};


size_t N = 8;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
extern int enzyme_const;

extern void __enzyme_autodiff(void *, ...);

extern void __enzyme_fwddiff(void *, ...);

extern double* __enzyme_todense(void *, ...) noexcept;


/// Energy of a ring of unit-rest-length springs:
///   sum_i (|input[i+1] - input[i]| - 1)^2.
/// Note this reads input[N]; the pointer must either provide N+1 elements
/// or wrap indices modulo N (as the mod_load __enzyme_todense wrapper
/// presumably arranges — see hess_f).
double f(size_t N, double* input) {
    double energy = 0;
    for (size_t j = 0; j < N; j++) {
        double d = input[j + 1] - input[j];
        // sqrt(d*d) == |d|; kept in sqrt-of-square form rather than fabs,
        // matching the original (likely so the derivative structure is what
        // the sparsity pass expects — confirm before simplifying).
        double len = sqrt(d * d);
        energy += (len - 1) * (len - 1);
    }
    return energy;
}

/// Perform dinput += gradient(f) via Enzyme reverse mode.
/// N is treated as a constant (enzyme_const); input and dinput form a
/// primal/shadow pair (enzyme_dup), so the gradient with respect to input
/// is accumulated into dinput.
void grad_f(size_t N, double* input, double* dinput) {
    __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput);
}


// Store callback for the identity-tangent dense wrapper (see hess_f):
// the tangent seed is read-only, so any store into it is a bug.
// NOTE(review): the assert message says "load" but this is the store
// callback — looks like a copy-paste slip; confirm before relying on it.
void ident_store(double , int64_t idx, size_t i) {
    assert(0 && "should never load");
}

/// Load callback for the i-th unit tangent vector: reading byte offset
/// idx yields 1.0 exactly when the element index equals i modulo N
/// (the ring wraps), and 0.0 otherwise.
__attribute__((always_inline))
double ident_load(int64_t idx, size_t i, size_t N) {
    // Convert byte offset to element index.
    idx /= sizeof(double);
    const bool on_diagonal = (idx % N) == (i % N);
    return on_diagonal ? 1.0 : 0.0;
}

// Accumulate one Hessian entry into the triplet list, wrapping the column
// modulo N for the ring topology (row is expected to already be < N —
// TODO confirm; it is also wrapped below for safety).
__attribute__((enzyme_sparse_accumulate))
void inner_store(int64_t row, int64_t col, double val, std::vector<triple> &triplets) {
    // BUG FIX: %d with int64_t/size_t arguments is undefined behavior;
    // use 64-bit format specifiers with explicit casts.
    printf("row=%lld col=%llu val=%f\n", (long long)row,
           (unsigned long long)(col % N), val);
    // assert(abs(val) > 0.00001);
    triplets.emplace_back(row % N, col % N, val);
}

/// Store callback for the sparse gradient shadow: record a nonzero
/// derivative written at byte offset idx as a triplet for row i.
__attribute__((always_inline))
void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector<triple> &triplets) {
    if (val != 0.0) {
        // Convert byte offset to element index (unsigned division, as
        // sizeof(double) promotes the operands).
        idx /= sizeof(double);
        inner_store(i, idx, val, triplets);
    }
}

/// Load callback for the sparse gradient shadow: every read sees zero,
/// so stores into the shadow accumulate pure derivative contributions.
__attribute__((always_inline))
double sparse_load(int64_t byte_off, size_t row, size_t n, std::vector<triple> &out) {
    return 0.0;
}

// Store callback for the read-only primal input wrapper (see hess_f):
// the primal positions are never written during differentiation, so any
// store is a bug.
__attribute__((always_inline))
void never_store(double val, int64_t idx, double* input, size_t N) {
    assert(0 && "this is a read only input, why are you storing here...");
}

/// Load callback implementing the ring topology: the byte offset is
/// converted to an element index and wrapped modulo N, so input[N]
/// aliases input[0].
__attribute__((always_inline))
double mod_load(int64_t idx, double* input, size_t N) {
    const size_t elem = (size_t)idx / sizeof(double);
    return input[elem % N];
}

// Compute the Hessian of f as sparse triplets, one row per iteration,
// using forward-over-reverse: each __enzyme_fwddiff sweep pushes the i-th
// unit tangent through grad_f. The exact shape of the __enzyme_todense /
// __enzyme_fwddiff calls is presumably what the sparsity pass pattern-
// matches — do not reorder or simplify without re-running the test.
__attribute__((noinline))
std::vector<triple> hess_f(size_t N, double* input) {
    std::vector<triple> triplets;
    // Wrap the primal input: loads wrap modulo N (ring), stores abort.
    input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N);
    __builtin_assume(N > 0);
    for (size_t i=0; i<N; i++) {
        __builtin_assume(i < 100000000);
        // d_input: dense identity row i (the tangent seed), read-only.
        double* d_input = __enzyme_todense((void*)ident_load, (void*)ident_store, i, N);
        // d_dinput: sparse accumulator — stores land in `triplets` as row i.
        double* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, N, &triplets);

        // (double*)0x1 is a dummy primal output pointer; the primal result
        // is not needed (enzyme_dupnoneed).
        __enzyme_fwddiff((void*)grad_f,
                         enzyme_const, N,
                         enzyme_dup, input, d_input,
                         enzyme_dupnoneed, (double*)0x1, d_dinput);

    }
    return triplets;
}


/// Build a quadratic test configuration, compute the sparse Hessian of the
/// ring-spring energy, and print every recorded entry.
int main() {
    // BUG FIX: `double x[N]` with a runtime N is a VLA, a non-standard
    // extension in C++; use std::vector (already included) instead.
    std::vector<double> x(N);
    // BUG FIX: loop index was `int` compared against size_t N.
    for (size_t i = 0; i < N; i++)
        x[i] = (double)((i + 1) * (i + 1));

    auto res = hess_f(N, x.data());

    // BUG FIX: %zu is the correct format specifier for size_t (was %ld).
    printf("%zu\n", res.size());

    for (const auto &tup : res)
        printf("%zu, %zu = %f\n", tup.row, tup.col, tup.val);

    return 0;
}

0 comments on commit 82318b8

Please sign in to comment.