Skip to content

Commit

Permalink
continued fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 15, 2024
1 parent 8b3d3af commit 402686a
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 39 deletions.
14 changes: 14 additions & 0 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,20 @@ void RemoveRedundantIVs(
// and must thus be expanded after all phi's
Value *NewIV =
Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI());

// Explicitly preserve wrap behavior from original iv. This is necessary
// until this PR in llvm is merged:
// https://github.com/llvm/llvm-project/pull/78199
if (auto addrec = dyn_cast<SCEVAddRecExpr>(S)) {
if (addrec->getLoop()->getHeader() == Header) {
if (auto add_or_mul = dyn_cast<BinaryOperator>(NewIV)) {
if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW))
add_or_mul->setHasNoUnsignedWrap(true);
if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW))
add_or_mul->setHasNoSignedWrap(true);
}
}
}
replacer(Tmp, NewIV);
eraser(Tmp);
}
Expand Down
125 changes: 103 additions & 22 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2818,7 +2818,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
return "OrZero";
}
// or a, 1 -> 1
if (C->isOne()) {
if (C->isOne() && cur->getType()->isIntegerTy(1)) {
replaceAndErase(cur, C);
return "OrOne";
}
Expand All @@ -2829,7 +2829,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
for (int i = 0; i < 2; i++) {
if (auto C = dyn_cast<ConstantInt>(cur->getOperand(i))) {
// and a, 1 -> a
if (C->isOne()) {
if (C->isOne() && cur->getType()->isIntegerTy(1)) {
replaceAndErase(cur, cur->getOperand(1 - i));
return "AndOne";
}
Expand Down Expand Up @@ -3008,6 +3008,41 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
return val;
};

if (auto II = dyn_cast<IntrinsicInst>(cur))
if (II->getIntrinsicID() == Intrinsic::fmuladd || II->getIntrinsicID() == Intrinsic::fma) {
B.setFastMathFlags(getFast());
auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1)));
auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2)));
replaceAndErase(cur, add);
return "FMulAddExpand";
}

// (lshr|udiv|sdiv exact (mul a, C1), C2) -> mul a, (C1 / D) if D divides C1,
// where D is (1 << C2) for lshr and C2 for udiv/sdiv
if ((cur->getOpcode() == Instruction::LShr || cur->getOpcode() == Instruction::SDiv || cur->getOpcode() == Instruction::UDiv) && cur->isExact())
if (auto C2 = dyn_cast<ConstantInt>(cur->getOperand(1)))
if (auto mul = dyn_cast<BinaryOperator>(cur->getOperand(0)))
if (mul->getOpcode() == Instruction::Mul)
for (int i0=0; i0<2; i0++)
if (auto C1 = dyn_cast<ConstantInt>(mul->getOperand(i0))) {
auto lhs = C1->getValue();
APInt rhs = C2->getValue();
if (cur->getOpcode() == Instruction::LShr) {
rhs = APInt(rhs.getBitWidth(), 1) << rhs;
}

APInt div, rem;
if (cur->getOpcode() == Instruction::LShr || cur->getOpcode() == Instruction::UDiv)
APInt::udivrem(lhs, rhs, div, rem);
else
APInt::sdivrem(lhs, rhs, div, rem);
if (rem.isZero()) {
auto res = B.CreateMul( mul->getOperand(1-i0), ConstantInt::get(cur->getType(), div), "mdiv." + cur->getName(), mul->hasNoUnsignedWrap(), mul->hasNoSignedWrap());
push(mul);
replaceAndErase(cur, res);
return "IMulDivConst";
}
}

// mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (mul const1, const2)
if (cur->getOpcode() == Instruction::FMul)
Expand Down Expand Up @@ -3915,7 +3950,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,

// (a | b) == 0 -> a == 0 & b == 0
if (auto icmp = dyn_cast<ICmpInst>(cur))
if (icmp->getPredicate() == ICmpInst::ICMP_EQ)
if (icmp->getPredicate() == ICmpInst::ICMP_EQ && cur->getType()->isIntegerTy(1))
for (int i = 0; i < 2; i++)
if (auto C = dyn_cast<ConstantInt>(icmp->getOperand(i)))
if (C->isZero())
Expand Down Expand Up @@ -3995,7 +4030,9 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
SmallVector<Instruction *, 1> precasts;
Value *lhs = nullptr;

Value *prelhs = cur->getOperand(0);
Value *prelhs = (cur->getOpcode() == Instruction::FNeg) ?
ConstantFP::get(cur->getType(), 0.0) :
cur->getOperand(0);
Value *prerhs = (cur->getOpcode() == Instruction::FNeg)
? cur->getOperand(0)
: cur->getOperand(1);
Expand Down Expand Up @@ -4328,11 +4365,20 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
push(SI);
auto ntval = (tvalC && tvalC->isZero())
? tvalC
: pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur));
: pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur, "sfdiv2_t." + cur->getName()));
auto nfval =
(fvalC && fvalC->isZero())
? fvalC
: pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur));
: pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur, "sfdiv2_f." + cur->getName()));

// Work around bad fdivfmf, fixed in LLVM 16+
// https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566
#if LLVM_VERSION_MAJOR < 16
for (auto v : {ntval, nfval})
if (auto I = dyn_cast<Instruction>(v))
I->setFastMathFlags(cur->getFastMathFlags());
#endif

auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval,
"sfdiv2." + cur->getName()));

Expand Down Expand Up @@ -5206,6 +5252,12 @@ bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) {
return false;
return true;
}
if (auto M = dyn_cast<SCEVUDivExpr>(S)) {
for (auto o : M->operands())
if (!cannotDependOnLoopIV(o, L))
return false;
return true;
}
if (auto UV = dyn_cast<SCEVUnknown>(S)) {
auto U = UV->getValue();
if (isa<Argument>(U))
Expand Down Expand Up @@ -5602,6 +5654,7 @@ return true;
}

if (isEqual) {
if (Loop)
if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node))
if (rep != rhs->node) {
auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx);
Expand All @@ -5625,6 +5678,7 @@ return true;
}

if (rhs->isEqual) {
if (rhs->Loop)
if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node))
if (rep != node) {
auto newlhs = make_compare(rep, isEqual, Loop, ctx);
Expand Down Expand Up @@ -5760,11 +5814,13 @@ return true;
else
insert(ivVals, shared_from_this());

ConstraintContext ctxd(ctx, shared_from_this(), rhs);

for (const auto &iv : ivVals) {
SetTy nextunionVals;
bool midchanged = false;
for (auto &uv : unionVals) {
auto tmp = iv->andB(uv, ctx);
auto tmp = iv->andB(uv, ctxd);
if (!tmp) {
midchanged = false;
nextunionVals = unionVals;
Expand Down Expand Up @@ -5814,8 +5870,10 @@ return true;

if (changed) {
auto cur = Constraints::none();
for (auto uv : unionVals)
cur = cur->orB(uv, ctx);
for (auto uv : unionVals) {
cur = cur->orB(uv, ctxd);
if (!cur) break;
}

if (*cur != *rhs)
return andB(cur, ctx);
Expand Down Expand Up @@ -6017,11 +6075,9 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
if (isEqual) {
return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)};
}
llvm::errs() << *this << "\n";
llvm_unreachable("Constraint ne not handled");
// EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ",
// F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No
// sparsification: not sparse solvable(nosoltn): solutions:",*solutions);
EmitFailure("NoSparsification", IP->getDebugLoc(), IP, "Negated solution not handled: ", *this);
assert(0);
return {};
}
case Type::Union: {
SmallVector<std::pair<Value *, Value *>, 1> vals;
Expand All @@ -6031,21 +6087,36 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
return vals;
}
case Type::Intersect:{
{
SmallVector<InnerTy, 1> vals(values.begin(), values.end());
ssize_t unionidx = -1;
for (int i=0; i<vals.size(); i++) {
if (vals[i]->ty == Type::Union) {
unionidx = i;
bool allne = true;
for (auto &v : vals[i]->values) {
if (v->ty != Type::Compare ||
v->isEqual) {
allne = false;
break;
}
}
if (allne) break;
}
}
if (unionidx != -1) {
auto others = Constraints::all();
for (int j=0; j<vals.size(); j++)
if (i != j)
if (unionidx != j)
others = others->andB(vals[j], ctx);
SmallVector<std::pair<Value *, Value *>, 1> resvals;
for (auto &v : vals[i]->values) {
for (auto &v : vals[unionidx]->values) {
auto tmp = v->andB(others, ctx);
for (const auto& sol : tmp->allSolutions(Exp, T, IP, ctx, B))
resvals.push_back(sol);
}
return resvals;
}
}
}
Value *solVal = nullptr;
Value *cond = ConstantInt::getTrue(T->getContext());
Expand Down Expand Up @@ -6284,15 +6355,13 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
// Full simplification
while (!Q.empty()) {
auto cur = Q.pop_back_val();
/*
std::set<Instruction *> prev;
for (auto v : Q)
prev.insert(v);
llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n";
*/
auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
(void)changed;
/*

if (changed) {
llvm::errs() << "changed: " << *changed << "\n";

Expand All @@ -6301,9 +6370,10 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
llvm::errs() << " + " << *I << "\n";
llvm::errs() << F << "\n\n";
}
*/
}

llvm::errs() << " post fix inner " << F << "\n";

SmallVector<std::pair<BasicBlock *, BranchInst *>, 1> sparseBlocks;
bool legalToSparse = true;
for (auto &B : F)
Expand Down Expand Up @@ -6490,8 +6560,19 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
forSparsification[L].second.emplace_back(blk, solutions);
}

if (sawError)
if (sawError) {
for (auto & pair : forSparsification) {
for (auto PN : {pair.second.first.first, pair.second.first.second}) {
PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
PN->eraseFromParent();
}
}
if (llvm::verifyFunction(F, &llvm::errs())) {
llvm::errs() << F << "\n";
report_fatal_error("function failed verification (6)");
}
return;
}

if (forSparsification.size() == 0) {
llvm::errs() << " found no stores for sparsification\n";
Expand Down
13 changes: 4 additions & 9 deletions enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <stdio.h>
#include <assert.h>
#include <vector>
#include <random>


#include<math.h>
Expand All @@ -26,8 +25,6 @@ struct triple {
};


size_t N;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
Expand Down Expand Up @@ -73,7 +70,7 @@ double ident_load(size_t idx, size_t i, size_t N) {
}

__attribute__((enzyme_sparse_accumulate))
void inner_store(int64_t row, int64_t col, double val, std::vector<triple> &triplets) {
void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector<triple> &triplets) {
printf("row=%d col=%d val=%f\n", row, col % N, val);
// assert(abs(val) > 0.00001);
triplets.emplace_back(row % N, col % N, val);
Expand All @@ -83,7 +80,7 @@ __attribute__((always_inline))
void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector<triple> &triplets) {
if (val == 0.0) return;
idx /= sizeof(double);
inner_store(i, idx, val, triplets);
inner_store(i, idx, N, val, triplets);
}

__attribute__((always_inline))
Expand Down Expand Up @@ -132,8 +129,6 @@ std::vector<triple> hess_f2(size_t N, double* input) {
*/
// int argc, char** argv
int __attribute__((always_inline)) main() {
std::mt19937 generator(0); // Seed the random number generator
std::uniform_real_distribution<double> normal(0, 0.05);


// if (argc != 2) {
Expand All @@ -147,8 +142,8 @@ int __attribute__((always_inline)) main() {
double x[2 * N + 2];
for (int i = 0; i < N; ++i) {
double angle = 2 * M_PI * i / N;
x[2 * i] = cos(angle) + normal(generator);
x[2 * i + 1] = sin(angle) + normal(generator);
x[2 * i] = cos(angle) ;//+ normal(generator);
x[2 * i + 1] = sin(angle) ;//+ normal(generator);
}
x[2 * N] = x[0];
x[2 * N + 1] = x[1];
Expand Down
17 changes: 9 additions & 8 deletions enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
Expand All @@ -26,8 +26,6 @@ struct triple {
};


size_t N;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
Expand All @@ -43,7 +41,9 @@ extern double* __enzyme_todense(void *, ...) noexcept;
__attribute__((always_inline))
static double f(size_t N, double* pos) {
double e = 0.;
for (size_t i = 0; i < N; i += 3) {
__builtin_assume(N != 0);
for (size_t i = 0; i < N; i+=3) {
__builtin_assume(i < 1000000000);
double vx = pos[i];
double vy = pos[i + 1];
double vz = pos[i + 2];
Expand Down Expand Up @@ -74,7 +74,7 @@ static double ident_load(int64_t idx, size_t i, size_t N) {
}

__attribute__((enzyme_sparse_accumulate))
static void inner_store(int64_t row, int64_t col, double val, std::vector<triple> &triplets) {
static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector<triple> &triplets) {
printf("row=%d col=%d val=%f\n", row, col % N, val);
// assert(abs(val) > 0.00001);
triplets.emplace_back(row % N, col % N, val);
Expand All @@ -84,7 +84,7 @@ __attribute__((always_inline))
static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector<triple> &triplets) {
if (val == 0.0) return;
idx /= sizeof(double);
inner_store(i, idx, val, triplets);
inner_store(i, idx, N, val, triplets);
}

__attribute__((always_inline))
Expand All @@ -108,6 +108,7 @@ std::vector<triple> hess_f(size_t N, double* input) {
std::vector<triple> triplets;
// input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N);
__builtin_assume(N > 0);
__builtin_assume(N < 10000000000);
for (size_t i=0; i<N; i++) {
__builtin_assume(i < 100000000);
double* d_input = __enzyme_todense((void*)ident_load, (void*)ident_store, i, N);
Expand Down

0 comments on commit 402686a

Please sign in to comment.