From 402686aada871d9a00020e5441016986f00da056 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 15 Jan 2024 15:51:50 -0500 Subject: [PATCH] continued fixes --- enzyme/Enzyme/CacheUtility.cpp | 14 ++ enzyme/Enzyme/FunctionUtils.cpp | 125 +++++++++++++++--- .../Sparse/ringspring2Dextenddata.cpp | 13 +- .../Sparse/ringspring3Dextenddata.cpp | 17 +-- 4 files changed, 130 insertions(+), 39 deletions(-) diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index fd34be6c1146..fa07681bdc20 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -257,6 +257,20 @@ void RemoveRedundantIVs( // and must thus be expanded after all phi's Value *NewIV = Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI()); + + // Explicitly preserve wrap behavior from original iv. This is necessary + // until this PR in llvm is merged: + // https://github.com/llvm/llvm-project/pull/78199 + if (auto addrec = dyn_cast(S)) { + if (addrec->getLoop()->getHeader() == Header) { + if (auto add_or_mul = dyn_cast(NewIV)) { + if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW)) + add_or_mul->setHasNoUnsignedWrap(true); + if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW)) + add_or_mul->setHasNoSignedWrap(true); + } + } + } replacer(Tmp, NewIV); eraser(Tmp); } diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 03093423fea3..c316500565fe 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2818,7 +2818,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "OrZero"; } // or a, 1 -> 1 - if (C->isOne()) { + if (C->isOne() && cur->getType()->isIntegerTy(1)) { replaceAndErase(cur, C); return "OrOne"; } @@ -2829,7 +2829,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (int i = 0; i < 2; i++) { if (auto C = dyn_cast(cur->getOperand(i))) { // and a, 1 -> a - if (C->isOne()) { + if (C->isOne() && cur->getType()->isIntegerTy(1)) { 
replaceAndErase(cur, cur->getOperand(1 - i)); return "AndOne"; } @@ -3008,6 +3008,41 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } return val; }; + + if (auto II = dyn_cast(cur)) + if (II->getIntrinsicID() == Intrinsic::fmuladd || II->getIntrinsicID() == Intrinsic::fma) { + B.setFastMathFlags(getFast()); + auto mul = pushcse(B.CreateFMul(II->getOperand(0), II->getOperand(1))); + auto add = pushcse(B.CreateFAdd(mul, II->getOperand(2))); + replaceAndErase(cur, add); + return "FMulAddExpand"; + } + + // (lshr exact (mul a, C1), C2), C -> mul a, (lshr exact C1, C2) if C2 divides C1 + if ((cur->getOpcode() == Instruction::LShr || cur->getOpcode() == Instruction::SDiv || cur->getOpcode() == Instruction::UDiv) && cur->isExact()) + if (auto C2 = dyn_cast(cur->getOperand(1))) + if (auto mul = dyn_cast(cur->getOperand(0))) + if (mul->getOpcode() == Instruction::Mul) + for (int i0=0; i0<2; i0++) + if (auto C1 = dyn_cast(mul->getOperand(i0))) { + auto lhs = C1->getValue(); + APInt rhs = C2->getValue(); + if (cur->getOpcode() == Instruction::LShr) { + rhs = APInt(rhs.getBitWidth(), 1) << rhs; + } + + APInt div, rem; + if (cur->getOpcode() == Instruction::LShr || cur->getOpcode() == Instruction::UDiv) + APInt::udivrem(lhs, rhs, div, rem); + else + APInt::sdivrem(lhs, rhs, div, rem); + if (rem.isZero()) { + auto res = B.CreateMul( mul->getOperand(1-i0), ConstantInt::get(cur->getType(), div), "mdiv." 
+ cur->getName(), mul->hasNoUnsignedWrap(), mul->hasNoSignedWrap()); + push(mul); + replaceAndErase(cur, res); + return "IMulDivConst"; + } + } // mul (mul a, const1), (mul b, const2) -> mul (mul a, b), (const1, const2) if (cur->getOpcode() == Instruction::FMul) @@ -3915,7 +3950,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, // (a | b) == 0 -> a == 0 & b == 0 if (auto icmp = dyn_cast(cur)) - if (icmp->getPredicate() == ICmpInst::ICMP_EQ) + if (icmp->getPredicate() == ICmpInst::ICMP_EQ && cur->getType()->isIntegerTy(1)) for (int i = 0; i < 2; i++) if (auto C = dyn_cast(icmp->getOperand(i))) if (C->isZero()) @@ -3995,7 +4030,9 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, SmallVector precasts; Value *lhs = nullptr; - Value *prelhs = cur->getOperand(0); + Value *prelhs = (cur->getOpcode() == Instruction::FNeg) ? + ConstantFP::get(cur->getType(), 0.0) : + cur->getOperand(0); Value *prerhs = (cur->getOpcode() == Instruction::FNeg) ? cur->getOperand(0) : cur->getOperand(1); @@ -4328,11 +4365,20 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, push(SI); auto ntval = (tvalC && tvalC->isZero()) ? tvalC - : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur)); + : pushcse(B.CreateFDivFMF(SI->getTrueValue(), b, cur, "sfdiv2_t." + cur->getName())); auto nfval = (fvalC && fvalC->isZero()) ? fvalC - : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur)); + : pushcse(B.CreateFDivFMF(SI->getFalseValue(), b, cur, "sfdiv2_f." + cur->getName())); + + // Work around bad fdivfmf, fixed in LLVM 16+ + // https://github.com/llvm/llvm-project/commit/4f3b1c6dd6ef6c7b5bb79f058e3b7ba4bcdf4566 +#if LLVM_VERSION_MAJOR < 16 + for (auto v : {ntval, nfval}) + if (auto I = dyn_cast(v)) + I->setFastMathFlags(cur->getFastMathFlags()); +#endif + auto res = pushcse(B.CreateSelect(SI->getCondition(), ntval, nfval, "sfdiv2." 
+ cur->getName())); @@ -5206,6 +5252,12 @@ bool cannotDependOnLoopIV(const SCEV *S, const Loop *L) { return false; return true; } + if (auto M = dyn_cast(S)) { + for (auto o : M->operands()) + if (!cannotDependOnLoopIV(o, L)) + return false; + return true; + } if (auto UV = dyn_cast(S)) { auto U = UV->getValue(); if (isa(U)) @@ -5602,6 +5654,7 @@ return true; } if (isEqual) { + if (Loop) if (auto rep = evaluateAtLoopIter(rhs->node, ctx.SE, Loop, node)) if (rep != rhs->node) { auto newrhs = make_compare(rep, rhs->isEqual, rhs->Loop, ctx); @@ -5625,6 +5678,7 @@ return true; } if (rhs->isEqual) { + if (rhs->Loop) if (auto rep = evaluateAtLoopIter(node, ctx.SE, rhs->Loop, rhs->node)) if (rep != node) { auto newlhs = make_compare(rep, isEqual, Loop, ctx); @@ -5760,11 +5814,13 @@ return true; else insert(ivVals, shared_from_this()); + ConstraintContext ctxd(ctx, shared_from_this(), rhs); + for (const auto &iv : ivVals) { SetTy nextunionVals; bool midchanged = false; for (auto &uv : unionVals) { - auto tmp = iv->andB(uv, ctx); + auto tmp = iv->andB(uv, ctxd); if (!tmp) { midchanged = false; nextunionVals = unionVals; @@ -5814,8 +5870,10 @@ return true; if (changed) { auto cur = Constraints::none(); - for (auto uv : unionVals) - cur = cur->orB(uv, ctx); + for (auto uv : unionVals) { + cur = cur->orB(uv, ctxd); + if (!cur) break; + } if (*cur != *rhs) return andB(cur, ctx); @@ -6017,11 +6075,9 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, if (isEqual) { return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)}; } - llvm::errs() << *this << "\n"; - llvm_unreachable("Constraint ne not handled"); - // EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", - // F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No - // sparsification: not sparse solvable(nosoltn): solutions:",*solutions); + EmitFailure("NoSparsification", IP->getDebugLoc(), IP, "Negated solution not handled: ", *this); + assert(0); + return {}; } 
case Type::Union: { SmallVector, 1> vals; @@ -6031,21 +6087,36 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP, return vals; } case Type::Intersect:{ + { SmallVector vals(values.begin(), values.end()); + ssize_t unionidx = -1; for (int i=0; ity == Type::Union) { + unionidx = i; + bool allne = true; + for (auto &v : vals[i]->values) { + if (v->ty != Type::Compare || + v->isEqual) { + allne = false; + break; + } + } + if (allne) break; + } + } + if (unionidx != -1) { auto others = Constraints::all(); for (int j=0; jandB(vals[j], ctx); SmallVector, 1> resvals; - for (auto &v : vals[i]->values) { + for (auto &v : vals[unionidx]->values) { auto tmp = v->andB(others, ctx); for (const auto& sol : tmp->allSolutions(Exp, T, IP, ctx, B)) resvals.push_back(sol); } return resvals; - } + } } Value *solVal = nullptr; Value *cond = ConstantInt::getTrue(T->getContext()); @@ -6284,15 +6355,13 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, // Full simplification while (!Q.empty()) { auto cur = Q.pop_back_val(); - /* std::set prev; for (auto v : Q) prev.insert(v); llvm::errs() << "\n\n\n\n" << F << "\ncur: " << *cur << "\n"; - */ auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL); (void)changed; - /* + if (changed) { llvm::errs() << "changed: " << *changed << "\n"; @@ -6301,9 +6370,10 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, llvm::errs() << " + " << *I << "\n"; llvm::errs() << F << "\n\n"; } - */ } + llvm::errs() << " post fix inner " << F << "\n"; + SmallVector, 1> sparseBlocks; bool legalToSparse = true; for (auto &B : F) @@ -6490,8 +6560,19 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, forSparsification[L].second.emplace_back(blk, solutions); } - if (sawError) + if (sawError) { + for (auto & pair : forSparsification) { + for (auto PN : {pair.second.first.first, pair.second.first.second}) { + 
PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->eraseFromParent(); + } + } + if (llvm::verifyFunction(F, &llvm::errs())) { + llvm::errs() << F << "\n"; + report_fatal_error("function failed verification (6)"); + } return; + } if (forSparsification.size() == 0) { llvm::errs() << " found no stores for sparsification\n"; diff --git a/enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp b/enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp index 1736cfa520e2..add70c298308 100644 --- a/enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp +++ b/enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include @@ -26,8 +25,6 @@ struct triple { }; -size_t N; - extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -73,7 +70,7 @@ double ident_load(size_t idx, size_t i, size_t N) { } __attribute__((enzyme_sparse_accumulate)) -void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { +void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col % N, val); // assert(abs(val) > 0.00001); triplets.emplace_back(row % N, col % N, val); @@ -83,7 +80,7 @@ __attribute__((always_inline)) void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); - inner_store(i, idx, val, triplets); + inner_store(i, idx, N, val, triplets); } __attribute__((always_inline)) @@ -132,8 +129,6 @@ std::vector hess_f2(size_t N, double* input) { */ // int argc, char** argv int __attribute__((always_inline)) main() { - std::mt19937 generator(0); // Seed the random number generator - std::uniform_real_distribution normal(0, 0.05); // if (argc != 2) { @@ -147,8 +142,8 @@ int __attribute__((always_inline)) main() { double x[2 * N + 2]; for (int i = 0; i < N; ++i) { double angle = 2 * M_PI * i / N; - x[2 * i] = cos(angle) + 
normal(generator); - x[2 * i + 1] = sin(angle) + normal(generator); + x[2 * i] = cos(angle) ;//+ normal(generator); + x[2 * i + 1] = sin(angle) ;//+ normal(generator); } x[2 * N] = x[0]; x[2 * N + 1] = x[1]; diff --git a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp index 1fb062f4faa9..9cfb73e97e66 100644 --- a/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp +++ b/enzyme/test/Integration/Sparse/ringspring3Dextenddata.cpp @@ -1,8 +1,8 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -mllvm -enable-load-pre=0 | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then 
%clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi @@ -26,8 +26,6 @@ struct triple { }; -size_t N; - extern int enzyme_dup; extern int enzyme_dupnoneed; extern int enzyme_out; @@ -43,7 +41,9 @@ extern double* __enzyme_todense(void *, ...) noexcept; __attribute__((always_inline)) static double f(size_t N, double* pos) { double e = 0.; - for (size_t i = 0; i < N; i += 3) { + __builtin_assume(N != 0); + for (size_t i = 0; i < N; i+=3) { + __builtin_assume(i < 1000000000); double vx = pos[i]; double vy = pos[i + 1]; double vz = pos[i + 2]; @@ -74,7 +74,7 @@ static double ident_load(int64_t idx, size_t i, size_t N) { } __attribute__((enzyme_sparse_accumulate)) -static void inner_store(int64_t row, int64_t col, double val, std::vector &triplets) { +static void inner_store(int64_t row, int64_t col, size_t N, double val, std::vector &triplets) { printf("row=%d col=%d val=%f\n", row, col % N, val); // assert(abs(val) > 0.00001); triplets.emplace_back(row % N, col % N, val); @@ -84,7 +84,7 @@ __attribute__((always_inline)) static void sparse_store(double val, int64_t idx, size_t i, size_t N, std::vector &triplets) { if (val == 0.0) return; idx /= sizeof(double); - inner_store(i, idx, val, triplets); + inner_store(i, idx, N, val, triplets); } __attribute__((always_inline)) @@ -108,6 +108,7 @@ std::vector hess_f(size_t N, double* input) { std::vector triplets; // input = __enzyme_todense((void*)mod_load, (void*)never_store, input, N); __builtin_assume(N > 0); + __builtin_assume(N < 10000000000); for (size_t i=0; i