Skip to content

Commit

Permalink
extend constraint language
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 9, 2024
1 parent fe1ed67 commit 0c6114c
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 56 deletions.
170 changes: 115 additions & 55 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3900,6 +3900,23 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
}

// Distribute a zero test over a bitwise or:
//   (a | b) == 0  ->  (a == 0) & (b == 0)
// The zero constant may sit on either side of the equality, so both operand
// positions are tried.
if (auto *eqCmp = dyn_cast<ICmpInst>(cur)) {
  if (eqCmp->getPredicate() == ICmpInst::ICMP_EQ) {
    for (int zeroIdx = 0; zeroIdx < 2; ++zeroIdx) {
      // One operand must be the integer constant zero ...
      auto *zeroC = dyn_cast<ConstantInt>(eqCmp->getOperand(zeroIdx));
      if (!zeroC || !zeroC->isZero())
        continue;
      // ... and the other one must be an `or` instruction.
      auto *orInst =
          dyn_cast<BinaryOperator>(eqCmp->getOperand(1 - zeroIdx));
      if (!orInst || orInst->getOpcode() != BinaryOperator::Or)
        continue;
      // Compare each side of the `or` against zero, then combine the results.
      auto lhsIsZero = pushcse(B.CreateICmpEQ(orInst->getOperand(0), zeroC));
      auto rhsIsZero = pushcse(B.CreateICmpEQ(orInst->getOperand(1), zeroC));
      auto bothZero = pushcse(B.CreateAnd(lhsIsZero, rhsIsZero));
      // NOTE(review): push() presumably re-queues the old instructions for
      // another simplification round -- confirm against its definition.
      push(orInst);
      push(eqCmp);
      replaceAndErase(cur, bothZero);
      return "OrEQZero";
    }
  }
}

// add (mul a b), (mul c, b) -> mul (add a, c), b
if (cur->getOpcode() == Instruction::Sub ||
cur->getOpcode() == Instruction::Add) {
Expand Down Expand Up @@ -5139,6 +5156,11 @@ class Constraints : public std::enable_shared_from_this<Constraints> {
else if (lhs->isEqual > rhs->isEqual)
return false;

if (lhs->Loop < rhs->Loop)
return true;
else if (lhs->Loop > rhs->Loop)
return false;

return lhs->values < rhs->values;
/*
auto lhss = lhs->values.size();
Expand All @@ -5164,19 +5186,21 @@ class Constraints : public std::enable_shared_from_this<Constraints> {
const SCEV *const node;
// whether equal to the node, or not equal to the node
bool isEqual;
// the loop of the iv comparing against.
const llvm::Loop* const Loop;
// using SetTy = SmallVector<InnerTy, 0>;
// using SetTy = SetVector<InnerTy, SmallVector<InnerTy, 0>,
// std::set<InnerTy>>;

Constraints() : ty(Type::Union), values(), node(nullptr), isEqual(false) {}
// Default constructor: an empty Union with no SCEV node and no associated
// loop.
Constraints() : ty(Type::Union), values(), node(nullptr), isEqual(false), Loop(nullptr) {}

Constraints(const SCEV *v, bool isEqual)
: ty(Type::Compare), values(), node(v), isEqual(isEqual) {}
Constraints(Type t) : ty(t), values(), node(nullptr), isEqual(false) {
// Builds a Compare constraint against the SCEV `v`.
//   isEqual - true for "equal to v", false for "not equal to v".
//   Loop    - the loop whose induction variable is being compared.
Constraints(const SCEV *v, bool isEqual, const llvm::Loop* Loop)
: ty(Type::Compare), values(), node(v), isEqual(isEqual), Loop(Loop) {}
// Builds one of the two payload-free constraints. Only Type::All and
// Type::None are accepted (asserted below); they carry no values, node or
// loop.
Constraints(Type t) : ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) {
assert(t == Type::All || t == Type::None);
}
Constraints(Type t, const SetTy &c)
: ty(t), values(c), node(nullptr), isEqual(false) {
: ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) {
assert(t != Type::All);
assert(t != Type::None);
assert(c.size() != 0);
Expand All @@ -5198,6 +5222,9 @@ class Constraints : public std::enable_shared_from_this<Constraints> {
if (isEqual != rhs.isEqual) {
return false;
}
if (Loop != rhs.Loop) {
return false;
}
if (values.size() != rhs.values.size()) {
return false;
}
Expand Down Expand Up @@ -5233,6 +5260,12 @@ return true;
if (isEqual > rhs.isEqual) {
return false;
}
if (Loop < rhs.Loop) {
return true;
}
if (Loop > rhs.Loop) {
return false;
}
if (values.size() < rhs.values.size()) {
return true;
}
Expand All @@ -5256,6 +5289,7 @@ return true;
unsigned hash() const {
unsigned res = 5 * (unsigned)ty +
DenseMapInfo<const SCEV *>::getHashValue(node) + isEqual;
res = llvm::detail::combineHashValue(res, (unsigned)(size_t)Loop);
for (auto v : values)
res = llvm::detail::combineHashValue(res, v->hash());
return res;
Expand Down Expand Up @@ -5308,7 +5342,7 @@ return true;
case Type::All:
return Constraints::none();
case Type::Compare:
return std::make_shared<Constraints>(node, !isEqual);
return std::make_shared<Constraints>(node, !isEqual, Loop);
case Type::Union: {
// not of or's is and of not's
SetTy next;
Expand Down Expand Up @@ -5437,6 +5471,7 @@ return true;

if (ty == Type::Compare && rhs->ty == Type::Compare) {
auto sub = SE.getMinusSCEV(node, rhs->node);
if (Loop == rhs->Loop)
if (auto cst = dyn_cast<SCEVConstant>(sub)) {
// the two solves are equivalent to each other
if (cst->getValue()->isZero()) {
Expand Down Expand Up @@ -5476,6 +5511,9 @@ return true;
SetTy vals;
insert(vals, shared_from_this());
insert(vals, rhs);
if (vals.size() == 1) {
llvm::errs() << "this: " << *this << " rhs: " << *rhs << "\n";
}
return std::make_shared<Constraints>(Type::Intersect, vals);
}
if (ty == Type::Intersect && rhs->ty == Type::Intersect) {
Expand Down Expand Up @@ -5518,10 +5556,16 @@ return true;
insert(vals, rhs);
return std::make_shared<Constraints>(Type::Intersect, vals);
}
if (ty == Type::Intersect && rhs->ty == Type::Union) {
if ((ty == Type::Intersect || ty == Type::Compare) && rhs->ty == Type::Union) {
SetTy unionVals = rhs->values;
bool changed = false;
for (const auto &iv : values) {
SetTy ivVals;
if (ty == Type::Intersect)
ivVals = values;
else
insert(ivVals, shared_from_this());

for (const auto &iv : ivVals) {
SetTy nextunionVals;
for (auto &uv : unionVals) {
auto tmp = iv->andB(uv, SE);
Expand All @@ -5530,7 +5574,7 @@ return true;
case Type::Compare:
case Type::Union:
insert(nextunionVals, tmp);
changed = true;
changed |= tmp != uv;
break;
case Type::Intersect:
insert(nextunionVals, uv);
Expand All @@ -5550,12 +5594,12 @@ return true;
return andB(cur, SE);
}

SetTy vals = values;
SetTy vals = ivVals;
insert(vals, rhs);
return std::make_shared<Constraints>(Type::Intersect, vals);
}
// Handled above via symmetry
if (rhs->ty == Type::Intersect) {
if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) {
return rhs->andB(shared_from_this(), SE);
}
// (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and
Expand All @@ -5579,6 +5623,7 @@ return true;
auto res = std::make_shared<Constraints>(Type::Intersect, vals);
return res;
}
llvm::errs() << " andB this: " << *this << " rhs: " << *rhs << "\n";
llvm_unreachable("Illegal predicate state");
}
// what this would be like when removing the following list of constraints
Expand All @@ -5599,30 +5644,8 @@ return true;
return std::make_shared<Constraints>(ty, res);
}
}
SmallVector<Value *, 1> allSolutions(SCEVExpander &Exp, llvm::Type *T,
Instruction *IP) const;
bool canEvaluateSolutions() const {
switch (ty) {
case Type::None:
return true;
case Type::All:
return false;
case Type::Compare:
if (isEqual) {
return true;
}
return false;
case Type::Union: {
for (auto v : values)
if (!v->canEvaluateSolutions())
return false;
return true;
}
case Type::Intersect:
return false;
}
return false;
}
// Expands this constraint into IR for the induction variable of `ivToSolve`.
// Each returned pair is (solution, condition):
//   - solution:  the value the induction variable should take, or nullptr
//                when the entry only contributes a condition (a comparison
//                against a different loop's induction variable);
//   - condition: a boolean value that guards the solution.
// `Exp`, `T` and `IP` are forwarded to SCEVExpander::expandCodeFor; `B` is
// the builder used to emit the comparisons.
SmallVector<std::pair<Value *, Value*>, 1> allSolutions(SCEVExpander &Exp, llvm::Type *T,
Instruction *IP, const llvm::Loop* ivToSolve, IRBuilder<>& B) const;
};

raw_ostream &operator<<(raw_ostream &os, const Constraints &c) {
Expand All @@ -5647,41 +5670,72 @@ raw_ostream &operator<<(raw_ostream &os, const Constraints &c) {
}
case Constraints::Type::Compare: {
if (c.isEqual) {
os << "(eq " << *c.node << ")";
os << "(eq " << *c.node << ", L=" << c.Loop->getHeader()->getName() << ")";
} else {
os << "(ne " << *c.node << ")";
os << "(ne " << *c.node << ", L=" << c.Loop->getHeader()->getName() << ")";
}
return os;
}
}
return os;
}

/// Expand this constraint into IR describing the values the induction
/// variable of \p ivToSolve must take.
///
/// \param Exp        expander used to materialize the SCEV `node`.
/// \param T          integer type the SCEV is expanded to.
/// \param IP         insertion point handed to the expander.
/// \param ivToSolve  loop whose induction variable is being solved for.
/// \param B          builder used to emit the guarding comparisons.
/// \returns one (solution, condition) pair per alternative. `solution` is the
///          expanded value, or nullptr when the entry only contributes a
///          condition on another loop's induction variable. `condition` is a
///          boolean value that must hold for the solution to apply.
///
/// Type::All, a `ne` comparison on the solved loop, and an Intersect with
/// either a non-unique sub-solution or more than one concrete solution are
/// not representable and abort via llvm_unreachable.
SmallVector<std::pair<Value *, Value *>, 1>
Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
                          const llvm::Loop *ivToSolve, IRBuilder<> &B) const {
  switch (ty) {
  case Type::None:
    return {};
  case Type::All:
    llvm::errs() << *this << "\n";
    llvm_unreachable("All not handled");
  case Type::Compare: {
    if (ivToSolve != Loop) {
      // The comparison is about a different loop: it cannot yield a value for
      // the induction variable being solved, only a runtime condition on the
      // other loop's induction variable.
      assert(ivToSolve);
      Value *ivVal = Exp.expandCodeFor(node, T, IP);
      Value *loopIV = Loop->getInductionVariable(*Exp.getSE());
      assert(loopIV && "loop has no recognizable induction variable");
      Value *cond = isEqual ? B.CreateICmpEQ(ivVal, loopIV)
                            : B.CreateICmpNE(ivVal, loopIV);
      return {std::make_pair(static_cast<Value *>(nullptr), cond)};
    }
    if (isEqual) {
      // Unconditional solution: the induction variable equals `node`.
      Value *alwaysTrue = ConstantInt::getTrue(T->getContext());
      return {std::make_pair(Exp.expandCodeFor(node, T, IP), alwaysTrue)};
    }
    llvm::errs() << *this << "\n";
    llvm_unreachable("Constraint ne not handled");
    // NOTE(review): a recoverable "NoSparsification" EmitFailure diagnostic
    // would be friendlier than aborting here.
  }
  case Type::Union: {
    // Every alternative of the union contributes its own solutions.
    SmallVector<std::pair<Value *, Value *>, 1> vals;
    for (const auto &v : values)
      for (auto sol : v->allSolutions(Exp, T, IP, ivToSolve, B))
        vals.push_back(sol);
    return vals;
  }
  case Type::Intersect: {
    // All members must hold at once: AND their conditions together and allow
    // at most one of them to provide the concrete solution value.
    Value *solVal = nullptr;
    Value *cond = ConstantInt::getTrue(T->getContext());
    for (const auto &v : values) {
      auto sols = v->allSolutions(Exp, T, IP, ivToSolve, B);
      if (sols.size() != 1) {
        llvm::errs() << *this << "\n";
        llvm_unreachable("Intersect not handled");
      }
      auto sol = sols[0];
      if (sol.first) {
        // A second concrete solution cannot be reconciled with the first.
        // (The previous check was inverted and rejected the first one.)
        if (solVal != nullptr) {
          llvm::errs() << *this << "\n";
          llvm_unreachable("Intersect not handled");
        }
        solVal = sol.first;
      }
      cond = B.CreateAnd(cond, sol.second);
    }
    return {std::make_pair(solVal, cond)};
  }
  }
  return {};
}
Expand Down Expand Up @@ -5879,11 +5933,13 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
if (auto icmp = dyn_cast<ICmpInst>(I)) {
auto lhs = SE.getSCEVAtScope(icmp->getOperand(0), L);
auto rhs = SE.getSCEVAtScope(icmp->getOperand(1), L);
llvm::errs() << " lhs: " << *lhs << "\n";
llvm::errs() << " rhs: " << *rhs << "\n";

auto sub1 = SE.getMinusSCEV(lhs, rhs);

if ( icmp->getPredicate() == ICmpInst::ICMP_EQ || icmp->getPredicate() == ICmpInst::ICMP_NE)
if (auto add = dyn_cast<SCEVAddRecExpr>(sub1)) {
if (add->getLoop() == L) {
if (add->isAffine()) {
// 0 === A + B * inc -> -A / B = inc
auto A = add->getStart();
Expand All @@ -5899,13 +5955,12 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
auto div_e = SE.getUDivExactExpr(MA, B);
if (div == div_e) {
auto res = std::make_shared<Constraints>(
div, icmp->getPredicate() == ICmpInst::ICMP_EQ);
div, icmp->getPredicate() == ICmpInst::ICMP_EQ, add->getLoop());
llvm::errs() << " getSparse(icmp, " << *I << ") = " << *res << "\n";
return res;
}
}
}
}
EmitFailure("NoSparsification", I->getDebugLoc(), I, " No sparsification: not sparse solvable(scev): ", *sub1);
legal = false;
return Constraints::all();
Expand Down Expand Up @@ -5950,12 +6005,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
continue;
}

if (!solutions->canEvaluateSolutions()) {
EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No sparsification: not sparse solvable(solneval): solutions:",*solutions);
sawError = true;
continue;
}
if (solutions == Constraints::none()) {
if (solutions == Constraints::none() || solutions == Constraints::all()) {
EmitFailure("NoSparsification", context->getDebugLoc(), context, "F: ", F, "\nL: ", *L, "\ncond: ", *cond, " negated:", negated, "\n No sparsification: not sparse solvable(nosoltn): solutions:",*solutions);
sawError = true;
}
Expand Down Expand Up @@ -6117,11 +6167,21 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
for (auto en : llvm::enumerate(pair.second.second)) {
auto off = en.index();
auto &solutions = en.value().second;
for (auto sol : solutions->allSolutions(Exp, idxty, phterm)) {
for (auto [sol, condition] : solutions->allSolutions(Exp, idxty, phterm, L, B)) {
SmallVector<Value *, 1> args(Inputs.begin(), Inputs.end());
args[off_idx] = ConstantInt::get(idxty, off);
args[induct_idx] = sol;
auto BB = B.GetInsertBlock();
auto B2 = BB->splitBasicBlock(BB->getTerminator(), "poststore");
B2->moveAfter(BB);
BB->getTerminator()->eraseFromParent();
B.SetInsertPoint(BB);
auto callB = BasicBlock::Create(BB->getContext(), "tostore", BB->getParent(), B2);
B.CreateCondBr(condition, callB, B2);
B.SetInsertPoint(callB);
B.CreateCall(F2, args);
B.CreateBr(B2);
B.SetInsertPoint(B2);
}
auto blk = en.value().first;
auto term = blk->getTerminator();
Expand Down
4 changes: 3 additions & 1 deletion enzyme/test/Integration/Sparse/ringspring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ static void ident_store(double , int64_t idx, size_t i) {
// Entry (idx, i) of an N-point periodic identity pattern: returns 1.0 when
// the element index matches `i`, where element index N wraps around to 0,
// and 0.0 otherwise.
//
// `idx` is first divided by sizeof(double) (presumably it arrives as a byte
// offset into an array of doubles -- confirm against the callers).
//
// The wrap is spelled out with compares instead of `idx % N == i`; the two
// forms disagree once the element index exceeds N.
__attribute__((always_inline))
double ident_load(int64_t idx, size_t i, size_t N) {
    idx /= sizeof(double);
    // Make the signed -> unsigned conversion explicit before comparing
    // against the size_t arguments.
    const size_t elem = static_cast<size_t>(idx);
    const size_t wrapped = (elem == N) ? 0 : elem;
    return static_cast<double>(wrapped == i);
}

__attribute__((enzyme_sparse_accumulate))
Expand Down

0 comments on commit 0c6114c

Please sign in to comment.