diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3425b265e1f1..1f13cbceced8 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -3900,6 +3900,23 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } + // (a | b) == 0 -> a == 0 & b == 0 + if (auto icmp = dyn_cast(cur)) + if (icmp->getPredicate() == ICmpInst::ICMP_EQ) + for (int i=0; i<2; i++) + if (auto C = dyn_cast(icmp->getOperand(i))) + if (C->isZero()) + if (auto z = dyn_cast(icmp->getOperand(1-i))) + if (z->getOpcode() == BinaryOperator::Or) { + auto a0 = pushcse(B.CreateICmpEQ(z->getOperand(0), C)); + auto b0 = pushcse(B.CreateICmpEQ(z->getOperand(1), C)); + auto res = pushcse(B.CreateAnd(a0, b0)); + push(z); + push(icmp); + replaceAndErase(cur, res); + return "OrEQZero"; + } + // add (mul a b), (mul c, b) -> mul (add a, c), b if (cur->getOpcode() == Instruction::Sub || cur->getOpcode() == Instruction::Add) { @@ -5139,6 +5156,11 @@ class Constraints : public std::enable_shared_from_this { else if (lhs->isEqual > rhs->isEqual) return false; + if (lhs->Loop < rhs->Loop) + return true; + else if (lhs->Loop > rhs->Loop) + return false; + return lhs->values < rhs->values; /* auto lhss = lhs->values.size(); @@ -5164,19 +5186,21 @@ class Constraints : public std::enable_shared_from_this { const SCEV *const node; // whether equal to the node, or not equal to the node bool isEqual; + // the loop of the iv comparing against. + const llvm::Loop* const Loop; // using SetTy = SmallVector; // using SetTy = SetVector, // std::set>; - Constraints() : ty(Type::Union), values(), node(nullptr), isEqual(false) {} + Constraints() : ty(Type::Union), values(), node(nullptr), isEqual(false), Loop(nullptr) {} - Constraints(const SCEV *v, bool isEqual) - : ty(Type::Compare), values(), node(v), isEqual(isEqual) {} - Constraints(Type t) : ty(t), values(), node(nullptr), isEqual(false) { + Constraints(const SCEV *v, bool isEqual, const llvm::Loop* Loop) + : ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {} + Constraints(Type t) : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) { assert(t == Type::All || t == Type::None); } Constraints(Type t, const SetTy &c) - : ty(t), values(c), node(nullptr), isEqual(false) { + : ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) { assert(t != Type::All); assert(t != Type::None); assert(c.size() != 0); @@ -5198,6 +5222,9 @@ class Constraints : public std::enable_shared_from_this { if (isEqual != rhs.isEqual) { return false; } + if (Loop != rhs.Loop) { + return false; + } if (values.size() != rhs.values.size()) { return false; } @@ -5233,6 +5260,12 @@ return true; if (isEqual > rhs.isEqual) { return false; } + if (Loop < rhs.Loop) { + return true; + } + if (Loop > rhs.Loop) { + return false; + } if (values.size() < rhs.values.size()) { return true; } @@ -5256,6 +5289,7 @@ return true; unsigned hash() const { unsigned res = 5 * (unsigned)ty + DenseMapInfo::getHashValue(node) + isEqual; + res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop); for (auto v : values) res = llvm::detail::combineHashValue(res, v->hash()); return res; @@ -5308,7 +5342,7 @@ return true; case Type::All: return Constraints::none(); case Type::Compare: - return std::make_shared(node, !isEqual); + return std::make_shared(node, !isEqual, Loop); case Type::Union: { // not of or's is and of not's SetTy next; @@ -5437,6 +5471,7 @@ return true; if (ty == Type::Compare && rhs->ty == Type::Compare) { auto sub = SE.getMinusSCEV(node, rhs->node); + if (Loop == rhs->Loop) if (auto cst = dyn_cast(sub)) { // the two solves are equivalent to each other if (cst->getValue()->isZero()) { @@ -5476,6 +5511,9 @@ return true; SetTy vals; insert(vals, shared_from_this()); insert(vals, rhs); + if (vals.size() == 1) { + llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n"; + } return std::make_shared(Type::Intersect, vals); } if (ty == Type::Intersect && rhs->ty == Type::Intersect) { @@ -5518,10 +5556,16 @@ return true; insert(vals, rhs); return std::make_shared(Type::Intersect, vals); } - if (ty == Type::Intersect && rhs->ty == Type::Union) { + if ((ty == Type::Intersect || ty == Type::Compare) && rhs->ty == Type::Union) { SetTy unionVals = rhs->values; bool changed = false; - for (const auto &iv : values) { + SetTy ivVals; + if (ty == Type::Intersect) + ivVals = values; + else + insert(ivVals, shared_from_this()); + + for (const auto &iv : ivVals) { SetTy nextunionVals; for (auto &uv : unionVals) { auto tmp = iv->andB(uv, SE); @@ -5530,7 +5574,7 @@ return true; case Type::Compare: case Type::Union: insert(nextunionVals, tmp); - changed = true; + changed |= tmp != uv; break; case Type::Intersect: insert(nextunionVals, uv); @@ -5550,12 +5594,12 @@ return true; return andB(cur, SE); } - SetTy vals = values; + SetTy vals = ivVals; insert(vals, rhs); return std::make_shared(Type::Intersect, vals); } // Handled above via symmetry - if (rhs->ty == Type::Intersect) { + if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { return rhs->andB(shared_from_this(), SE); } // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and @@ -5579,6 +5623,7 @@ return true; auto res = std::make_shared(Type::Intersect, vals); return res; } + llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n"; llvm_unreachable("Illegal predicate state"); } // what this would be like when removing the following list of constraints @@ -5599,30 +5644,8 @@ return true; return std::make_shared(ty, res); } } - SmallVector allSolutions(SCEVExpander &Exp, llvm::Type *T, - Instruction *IP) const; - bool canEvaluateSolutions() const { - switch (ty) { - case Type::None: - return true; - case Type::All: - return false; - case Type::Compare: - if (isEqual) { - return true; - } - return false; - case Type::Union: { - for (auto v : values) - if (!v->canEvaluateSolutions()) - return false; - return true; - } - case Type::Intersect: - return false; - } - return false; - } + SmallVector, 1> allSolutions(SCEVExpander &Exp, llvm::Type *T, + Instruction *IP, const llvm::Loop* ivToSolve, IRBuilder<>& B) const; }; raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { @@ -5647,9 +5670,9 @@ raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { } case Constraints::Type::Compare: { if (c.isEqual) { - os << "(eq " << *c.node << ")"; + os << "(eq " << *c.node << ", L=" << c.Loop->getHeader()->getName() << ")"; } else { - os << "(ne " << *c.node << ")"; + os << "(ne " << *c.node << ", L=" << c.Loop->getHeader()->getName() << ")"; } return os; } @@ -5657,31 +5680,62 @@ raw_ostream &operator<<(raw_ostream &os, const Constraints &c) { return os; } -SmallVector Constraints::allSolutions(SCEVExpander &Exp, +SmallVector, 1> Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, - Instruction *IP) const { + Instruction *IP, + const llvm::Loop* ivToSolve, IRBuilder<>& B) const { switch (ty) { case Type::None: return {}; case Type::All: llvm::errs() << *this << "\n"; llvm_unreachable("All not handled"); - case Type::Compare: + case Type::Compare: { + Value* cond = ConstantInt::getTrue(T->getContext()); + if (ivToSolve != Loop) { + assert(ivToSolve); + Value* ivVal = Exp.expandCodeFor(node, T, IP); + if (isEqual) + cond = B.CreateICmpEQ(ivVal, Loop->getInductionVariable(*Exp.getSE())); + else + cond = B.CreateICmpNE(ivVal, Loop->getInductionVariable(*Exp.getSE())); + return {std::make_pair((Value*)nullptr, cond)}; + } if (isEqual) { - return {Exp.expandCodeFor(node, T, IP)}; + return {std::make_pair(Exp.expandCodeFor(node, T, IP), cond)}; } llvm::errs() << *this << "\n"; llvm_unreachable("Constraint ne not handled"); + // EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No sparsification: not sparse solvable(nosoltn): solutions:",*solutions); + } case Type::Union: { - SmallVector vals; + SmallVector, 1> vals; for (auto v : values) - for (auto sol : v->allSolutions(Exp, T, IP)) + for (auto sol : v->allSolutions(Exp, T, IP, ivToSolve, B)) vals.push_back(sol); return vals; } case Type::Intersect: - llvm::errs() << *this << "\n"; - llvm_unreachable("Intersect not handled"); + Value* solVal = nullptr; + Value* cond = ConstantInt::getTrue(T->getContext()); + for (auto v : values) { + auto sols = v->allSolutions(Exp, T, IP, ivToSolve, B); + if (sols.size() != 1) { + llvm::errs() << *this << "\n"; + llvm_unreachable("Intersect not handled"); + } + auto sol = sols[0]; + if (sol.first) { + if (solVal == nullptr) { + llvm::errs() << *this << "\n"; + llvm_unreachable("Intersect not handled"); + } + assert(solVal == nullptr); + solVal = sol.first; + } + cond = B.CreateAnd(cond, sol.second); + } + return { std::make_pair(solVal, cond) }; } return {}; } @@ -5879,11 +5933,13 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, if (auto icmp = dyn_cast(I)) { auto lhs = SE.getSCEVAtScope(icmp->getOperand(0), L); auto rhs = SE.getSCEVAtScope(icmp->getOperand(1), L); + llvm::errs() << " lhs: " << *lhs << "\n"; + llvm::errs() << " rhs: " << *rhs << "\n"; auto sub1 = SE.getMinusSCEV(lhs, rhs); + if ( icmp->getPredicate() == ICmpInst::ICMP_EQ || icmp->getPredicate() == ICmpInst::ICMP_NE) if (auto add = dyn_cast(sub1)) { - if (add->getLoop() == L) { if (add->isAffine()) { // 0 === A + B * inc -> -A / B = inc auto A = add->getStart(); @@ -5899,13 +5955,12 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, auto div_e = SE.getUDivExactExpr(MA, B); if (div == div_e) { auto res = std::make_shared( - div, icmp->getPredicate() == ICmpInst::ICMP_EQ); + div, icmp->getPredicate() == ICmpInst::ICMP_EQ, add->getLoop()); llvm::errs() << " getSparse(icmp, " << *I << ") = " << *res << "\n"; return res; } } } - } EmitFailure("NoSparsification", I->getDebugLoc(), I, " No sparsification: not sparse solvable(scev): ", *sub1); legal = false; return Constraints::all(); @@ -5950,12 +6005,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, continue; } - if (!solutions->canEvaluateSolutions()) { - EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No sparsification: not sparse solvable(solneval): solutions:",*solutions); - sawError = true; - continue; - } - if (solutions == Constraints::none()) { + if (solutions == Constraints::none() || solutions == Constraints::all()) { EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No sparsification: not sparse solvable(nosoltn): solutions:",*solutions); sawError = true; } @@ -6117,11 +6167,21 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM, for (auto en : llvm::enumerate(pair.second.second)) { auto off = en.index(); auto &solutions = en.value().second; - for (auto sol : solutions->allSolutions(Exp, idxty, phterm)) { + for (auto [sol, condition] : solutions->allSolutions(Exp, idxty, phterm, L, B)) { SmallVector args(Inputs.begin(), Inputs.end()); args[off_idx] = ConstantInt::get(idxty, off); args[induct_idx] = sol; + auto BB = B.GetInsertBlock(); + auto B2 = BB->splitBasicBlock(BB->getTerminator(), "poststore"); + B2->moveAfter(BB); + BB->getTerminator()->eraseFromParent(); + B.SetInsertPoint(BB); + auto callB = BasicBlock::Create(BB->getContext(), "tostore", BB->getParent(), B2); + B.CreateCondBr(condition, callB, B2); + B.SetInsertPoint(callB); B.CreateCall(F2, args); + B.CreateBr(B2); + B.SetInsertPoint(B2); } auto blk = en.value().first; auto term = blk->getTerminator(); diff --git a/enzyme/test/Integration/Sparse/ringspring.cpp b/enzyme/test/Integration/Sparse/ringspring.cpp index 8bb40d96c157..99637763cf84 100644 --- a/enzyme/test/Integration/Sparse/ringspring.cpp +++ b/enzyme/test/Integration/Sparse/ringspring.cpp @@ -66,7 +66,9 @@ static void ident_store(double , int64_t idx, size_t i) { __attribute__((always_inline)) double ident_load(int64_t idx, size_t i, size_t N) { idx /= sizeof(double); - return (double)(idx % N == i % N);// ? 1.0 : 0.0; + // return (double)( ( (idx == N) ? 0 : idx) == i); + return (double)((idx != N && idx == i) || (idx == N && 0 == i)); + // return (double)( idx % N == i); } __attribute__((enzyme_sparse_accumulate))