Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ring of springs integration test using modulo #1521

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 50 additions & 46 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
B.SetInsertPoint(res);

if (auto PT = dyn_cast<PointerType>(storePointer->getType())) {
(void)PT;
#if LLVM_VERSION_MAJOR < 17
#if LLVM_VERSION_MAJOR >= 15
if (PT->getContext().supportsTypedPointers()) {
Expand Down Expand Up @@ -4329,56 +4330,58 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
// phi (idx=0) ? b, a, a -> select (idx == 0), b, a
if (auto L = LI.getLoopFor(PN->getParent()))
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
if (L->getHeader() == PN->getParent())
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
assert(ph_idx >= 0);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
}

auto val2 = PN->getIncomingValue(ph_idx);
push(val2);
auto val2 = PN->getIncomingValue(ph_idx);
push(val2);

auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}
auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}

replaceAndErase(cur, val);
return "PhiLoop0Sel";
replaceAndErase(cur, val);
return "PhiLoop0Sel";
}
}
}
// phi (sitofp a), (sitofp b) -> sitofp (phi a, b)
{
SmallVector<Value *, 1> negOps;
Expand Down Expand Up @@ -4657,7 +4660,8 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
legal = false;
auto L = LI.getLoopFor(PN->getParent());
if (legal && L && L->getLoopPreheader() &&
L->getCanonicalInductionVariable()) {
L->getCanonicalInductionVariable() &&
L->getHeader() == PN->getParent()) {
auto ph_idx = PN->getBasicBlockIndex(L->getLoopPreheader());
if (isa<ConstantInt>(PN->getIncomingValue(ph_idx))) {
lhsOps[ph_idx] =
Expand Down
49 changes: 42 additions & 7 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9158,14 +9158,15 @@ bool GradientUtils::needsCacheWholeAllocation(
return false;
if (!found->second)
return true;
SmallVector<std::pair<const Instruction *, size_t>, 1> todo;
// User, operand of input, whether the input is the original allocation
SmallVector<std::tuple<const Instruction *, size_t, bool>, 1> todo;
for (auto &use : origInst->uses())
todo.push_back(
std::make_pair(cast<Instruction>(use.getUser()), use.getOperandNo()));
SmallSet<std::pair<const Instruction *, size_t>, 1> seen;
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(), true));
SmallSet<std::tuple<const Instruction *, size_t, bool>, 1> seen;
while (todo.size()) {
auto pair = todo.back();
auto [cur, idx] = pair;
auto [cur, idx, orig] = pair;
todo.pop_back();
if (seen.count(pair))
continue;
Expand All @@ -9184,6 +9185,8 @@ bool GradientUtils::needsCacheWholeAllocation(
II->getIntrinsicID() == Intrinsic::masked_load)
continue;

bool returnedSameValue = false;

if (auto CI = dyn_cast<CallInst>(cur)) {
#if LLVM_VERSION_MAJOR >= 14
if (idx < CI->arg_size())
Expand All @@ -9193,6 +9196,36 @@ bool GradientUtils::needsCacheWholeAllocation(
{
if (isNoCapture(CI, idx))
continue;

if (auto F = CI->getCalledFunction())
if (F->getCallingConv() == CI->getCallingConv()) {
bool onlyReturnUses = true;
bool hasReturnUse = true;

for (auto u : F->getArg(idx)->users()) {
if (isa<ReturnInst>(u)) {
hasReturnUse = true;
continue;
}
onlyReturnUses = false;
continue;
}
// The arg itself has no use in the function
if (onlyReturnUses && !hasReturnUse)
continue;

// If this is the original allocation, we return it guaranteed, and
// cache the return, that's still fine
if (onlyReturnUses && orig) {
found = knownRecomputeHeuristic.find(cur);
if (found == knownRecomputeHeuristic.end())
continue;

if (!found->second)
continue;
returnedSameValue = true;
}
}
}
}

Expand All @@ -9202,6 +9235,7 @@ bool GradientUtils::needsCacheWholeAllocation(

// If caching this user, it cannot be a gep/cast of original
if (!found->second) {
llvm::errs() << " mod: " << *oldFunc->getParent() << "\n";
llvm::errs() << " oldFunc: " << *oldFunc << "\n";
for (auto &pair : knownRecomputeHeuristic)
llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n";
Expand All @@ -9211,8 +9245,9 @@ bool GradientUtils::needsCacheWholeAllocation(
} else {
// if not caching this user, it is legal to recompute, consider its users
for (auto &use : cur->uses()) {
todo.push_back(std::make_pair(cast<Instruction>(use.getUser()),
use.getOperandNo()));
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(),
returnedSameValue && orig));
}
}
}
Expand Down
136 changes: 136 additions & 0 deletions enzyme/test/Integration/Sparse/ringspring.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi

// everything should be always inline


#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#include <vector>

// One (row, col, value) entry of a sparse matrix, collected while
// assembling the Hessian in hess_f below.
struct triple {
    size_t row; // row index (callers reduce it mod N before constructing)
    size_t col; // column index (callers reduce it mod N before constructing)
    double val; // nonzero value stored at (row, col)
    // Declaring a move ctor suppresses the implicit copy operations, making
    // the type move-only; triples are only emplaced/moved into a vector.
    triple(triple&&) = default;
    triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {}
};


size_t N = 8;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
extern int enzyme_const;

extern void __enzyme_autodiff(void *, ...);

extern void __enzyme_fwddiff(void *, ...);

extern double* __enzyme_todense(void *, ...) noexcept;


/// Energy of a ring of N unit-rest-length springs:
///   sum_i (|input[i+1] - input[i]| - 1)^2.
/// Note input[i+1] is read for i == N-1, so index N must be valid for the
/// caller; in this test the pointer is wrapped with a modulo load, which
/// makes index N alias index 0 and closes the ring.
double f(size_t N, double* input) {
    double out = 0;
    for (size_t i = 0; i < N; i++) {
        double d = input[i + 1] - input[i];
        // |d| written as sqrt(d*d) so the AD tool differentiates the sqrt.
        double len = sqrt(d * d);
        double stretch = len - 1;
        out += stretch * stretch;
    }
    return out;
}

/// Accumulate the gradient of f with respect to `input` into `dinput`
/// (i.e. dinput += df/dinput) using Enzyme reverse-mode AD.
/// N is passed as enzyme_const so it is not differentiated.
void grad_f(size_t N, double* input, double* dinput) {
  __enzyme_autodiff((void*)f, enzyme_const, N, enzyme_dup, input, dinput);
}


// Store callback for the seed/tangent vector d_input: that vector is
// read-only from the derivative's perspective, so any store through the
// dense wrapper indicates a bug.
// Fix: the assert message previously said "should never load", which
// contradicted this being the *store* callback (ident_load handles loads).
void ident_store(double , int64_t idx, size_t i) {
    assert(0 && "should never store");
}

// Load callback for the seed vector of derivative pass i: behaves as a
// dense vector that is 1.0 at position (i mod N) and 0.0 elsewhere.
// `idx` arrives as a byte offset, so convert it to an element index first.
// Fix: removed GitHub review-UI text that had been pasted into the body
// ("marked this conversation as resolved" lines) and a stale trailing
// comment; the computation itself is unchanged.
__attribute__((always_inline))
double ident_load(int64_t idx, size_t i, size_t N) {
    idx /= sizeof(double);
    return (double)(idx % N == i % N);
}

// Record one Hessian nonzero, reducing both indices modulo N (ring wrap).
// Marked enzyme_sparse_accumulate so Enzyme treats it as a sparse
// accumulation target.
// Fix: row/col are int64_t and (col % N) is size_t, but the old format
// string used "%d" for all three — undefined behavior on LP64 targets.
__attribute__((enzyme_sparse_accumulate))
void inner_store(int64_t row, int64_t col, double val, std::vector<triple> &triplets) {
    printf("row=%lld col=%zu val=%f\n", (long long)row, (size_t)col % N, val);
    // assert(abs(val) > 0.00001);
    triplets.emplace_back(row % N, col % N, val);
}

// Accumulation callback for row i of the Hessian: structural zeros are
// skipped entirely; for a nonzero, the byte offset idx is converted to a
// column index and the triplet is recorded via inner_store.
__attribute__((always_inline))
void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector<triple> &triplets) {
    if (val != 0.0) {
        idx /= sizeof(double); // byte offset -> element (column) index
        inner_store(i, idx, val, triplets);
    }
}

// Load callback paired with sparse_store: the derivative output is
// write-only in this test, so every read observes zero.
__attribute__((always_inline))
double sparse_load(int64_t idx, size_t i, size_t N, std::vector<triple> &triplets) {
    (void)idx;
    (void)i;
    (void)N;
    (void)triplets;
    return 0.0;
}

// Store callback guarding the dense-wrapped primal input. The input is
// read-only in this test, so reaching this means something wrote through
// the wrapper — abort loudly.
__attribute__((always_inline))
void never_store(double val, int64_t idx, double* input, size_t N) {
  assert(0 && "this is a read only input, why are you storing here...");
}

// Ring (modulo) load for the primal input: convert the byte offset to an
// element index, then wrap modulo N so that index N aliases index 0 —
// this is what closes the ring of springs for f's input[i+1] read.
__attribute__((always_inline))
double mod_load(int64_t idx, double* input, size_t N) {
    size_t elem = (size_t)idx / sizeof(double);
    return input[elem % N];
}

/// Assemble the (sparse) Hessian of f at `input` as a list of triplets,
/// one forward-over-reverse derivative pass per row i.
/// NOTE(review): __enzyme_todense appears to wrap a pointer so that reads
/// and writes go through the supplied callbacks (load first, store
/// second) — confirm against the Enzyme todense API docs.
__attribute__((noinline))
std::vector<triple> hess_f(size_t N, double* input) {
std::vector<triple> triplets;
// Wrap the primal input so reads wrap modulo N (mod_load) and any write
// aborts (never_store): index N aliases index 0, closing the ring.
input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N);
__builtin_assume(N > 0);
for (size_t i=0; i<N; i++) {
__builtin_assume(i < 100000000);
// Seed vector e_i: reads as 1.0 at position (i mod N), writes abort.
double* d_input = __enzyme_todense((void*)ident_load, (void*)ident_store, i, N);
// Write-only sparse accumulator receiving row i of the Hessian.
double* d_dinput = __enzyme_todense((void*)sparse_load, (void*)sparse_store, i, N, &triplets);

// Forward mode over grad_f: the tangent of the gradient along e_i is
// Hessian row i. The primal output is unused (enzyme_dupnoneed), so a
// dummy non-null pointer (0x1) is passed for it.
__enzyme_fwddiff((void*)grad_f,
          enzyme_const, N,
          enzyme_dup, input, d_input,
          enzyme_dupnoneed, (double*)0x1, d_dinput);

}
return triplets;
}


// Driver: build a test configuration, compute the sparse Hessian of the
// ring-spring energy, and print every recorded triplet.
// Fixes: `double x[N]` with the non-constant global N was a (non-standard
// in C++) variable-length array — replaced with std::vector; the loop
// index was a signed int compared against size_t; and size_t values were
// printed with "%ld" instead of the portable "%zu".
int main() {
    std::vector<double> x(N);
    for (size_t i = 0; i < N; i++)
        x[i] = (double)((i + 1) * (i + 1));

    auto res = hess_f(N, x.data());

    printf("%zu\n", res.size());

    for (const auto &tup : res)
        printf("%zu, %zu = %f\n", tup.row, tup.col, tup.val);

    return 0;
}
Loading