Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 16, 2024
1 parent 1cead75 commit 57184ed
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 57 deletions.
79 changes: 33 additions & 46 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3293,6 +3293,7 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
std::optional<APFloat> constval;
bool changed = false;
for (auto &v : callOperands(P))

{
if (auto P2 = isProduct(v)) {
for (auto &v2 : callOperands(P2)) {
Expand Down Expand Up @@ -6006,9 +6007,6 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
if (br->getSuccessor(0) != PN->getParent()) {
continue;



}
if (br->getSuccessor(1) != PN->getIncomingBlock(1 - i)) {
continue;
Expand Down Expand Up @@ -6199,7 +6197,6 @@ class Constraints : public std::enable_shared_from_this<Constraints> {
: ty(t), values(), node(nullptr), isEqual(false), Loop(nullptr) {
assert(t == Type::All || t == Type::None);
}

Constraints(Type t, const SetTy &c, bool check = true)
: ty(t), values(c), node(nullptr), isEqual(false), Loop(nullptr) {
assert(t != Type::All);
Expand Down Expand Up @@ -6561,56 +6558,14 @@ return true;
return newlhs->andB(rhs, ctx);
}
}
}

if (isEqual) {
if (auto addrec = dyn_cast<SCEVAddRecExpr>(rhs->node)) {
if (addrec->isAffine() && addrec->getLoop() == Loop) {
auto node2 = addrec->evaluateAtIteration(node, ctx.SE);
auto newrhs = std::make_shared<Constraints>(node2, rhs->isEqual, rhs->Loop);
return andB(newrhs, ctx);
}
}
// not loop -> node == 0
if (!Loop) {
for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), ctx.SE.getMinusSCEV(rhs->node, node)}) {
llvm::errs() << " maybe replace lhs: " << *this << " rhs: " << *rhs << " sub1: " << *sub1 << "\n";
auto newrhs = std::make_shared<Constraints>(sub1, rhs->isEqual, rhs->Loop);
if (*newrhs == *this) return shared_from_this();
if (!isa<SCEVConstant>(rhs->node) && isa<SCEVConstant>(sub1)) {
return andB(newrhs, ctx);
}
}
}
}

if (rhs->isEqual) {
if (auto addrec = dyn_cast<SCEVAddRecExpr>(node)) {
if (addrec->isAffine() && addrec->getLoop() == rhs->Loop) {
auto node2 = addrec->evaluateAtIteration(rhs->node, ctx.SE);
auto newlhs = std::make_shared<Constraints>(node2, isEqual, Loop);
return newlhs->andB(rhs, ctx);
}
}
// not loop -> node == 0
if (!rhs->Loop) {
for (auto sub1 : {ctx.SE.getMinusSCEV(node, rhs->node), ctx.SE.getMinusSCEV(rhs->node, node)}) {
llvm::errs() << " maybe replace lhs2: " << *this << " rhs: " << *rhs << " sub1: " << *sub1 << "\n";
auto newlhs = std::make_shared<Constraints>(sub1, isEqual, Loop);
if (*newlhs == *this) return shared_from_this();
if (!isa<SCEVConstant>(node) && isa<SCEVConstant>(sub1)) {
return newlhs->andB(rhs, ctx);
}
}
}
}

if (!Loop && !rhs->Loop && isEqual == rhs->isEqual) {
if (node == ctx.SE.getNegativeSCEV(rhs->node))
return shared_from_this();
}


SetTy vals;
insert(vals, shared_from_this());
insert(vals, rhs);
Expand Down Expand Up @@ -6688,6 +6643,35 @@ return true;
if (legal) {
return newlhs->andB(tmp, ctx);
}
}
insert(vals, v);
}
}
}
if (!foldedIn) {
insert(vals, rhs);
return std::make_shared<Constraints>(Type::Intersect, vals);
} else {
auto cur = Constraints::all();
for (auto &iv : vals) {
auto cur2 = cur->andB(iv, ctx);
if (!cur2)
return nullptr;
cur = std::move(cur2);
}
return cur;
}
}
if ((ty == Type::Intersect || ty == Type::Compare) &&
rhs->ty == Type::Union) {
SetTy unionVals = rhs->values;
bool changed = false;
SetTy ivVals;
if (ty == Type::Intersect)
ivVals = values;
else
insert(ivVals, shared_from_this());

ConstraintContext ctxd(ctx, shared_from_this(), rhs);

for (const auto &iv : ivVals) {
Expand Down Expand Up @@ -7351,6 +7335,9 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
// actual store (if not negated) or to the store (if negated)
// if! negated the result may become more false if negated the
// result may become more true

//

// default is condition avoids sparse, negated is condition goes
// to sparse
Instruction *context =
Expand Down
17 changes: 11 additions & 6 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9259,17 +9259,22 @@ bool GradientUtils::needsCacheWholeAllocation(
continue;

if (auto F = CI->getCalledFunction())
if (F->getCallingConv() == CI->getCallingConv()) {
if (F->getCallingConv() == CI->getCallingConv() && !F->empty()) {
bool onlyReturnUses = true;
bool hasReturnUse = true;

for (auto u : F->getArg(idx)->users()) {
if (isa<ReturnInst>(u)) {
hasReturnUse = true;
if (CI->getFunctionType() != F->getFunctionType() ||
idx >= F->getFunctionType()->getNumParams()) {
onlyReturnUses = false;
} else {
for (auto u : F->getArg(idx)->users()) {
if (isa<ReturnInst>(u)) {
hasReturnUse = true;
continue;
}
onlyReturnUses = false;
continue;
}
onlyReturnUses = false;
continue;
}
// The arg itself has no use in the function
if (onlyReturnUses && !hasReturnUse)
Expand Down
5 changes: 3 additions & 2 deletions enzyme/test/Integration/Sparse/ringspring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct triple {
triple(size_t row, size_t col, double val) : row(row), col(col), val(val) {}
};


size_t N = 8;

extern int enzyme_dup;
Expand All @@ -46,8 +47,8 @@ static double f(size_t N, double* input) {
for (size_t i=0; i<N; i++) {
//double sub = input[i] - input[i+1];
// out += sub * sub;
double sub = input[(i + 1) % N] - input[i % N];
out += (sqrt(sub) + 1)*(sqrt(sub) + 1);
double sub = (input[i+1] - input[i]) * (input[i+1] - input[i]);
out += (sqrt(sub) - 1)*(sqrt(sub) - 1);
}
return out;
}
Expand Down
2 changes: 0 additions & 2 deletions enzyme/test/Integration/Sparse/ringspring2Dextenddata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ struct triple {
};


size_t N;

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -ffast-math -mllvm -enable-load-pre=0 -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-auto-sparsity=1 -S | %lli - ; fi


// everything should be always inline

#include <stdio.h>
Expand Down

0 comments on commit 57184ed

Please sign in to comment.