Skip to content

Commit

Permalink
add phi fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 8, 2023
1 parent 70e5271 commit 59ce655
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 52 deletions.
92 changes: 47 additions & 45 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4329,56 +4329,58 @@ std::optional<std::string> fixSparse_inner(Instruction *cur, llvm::Function &F,
}
// phi (idx=0) ? b, a, a -> select (idx == 0), b, a
if (auto L = LI.getLoopFor(PN->getParent()))
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
if (L->getHeader() == PN->getParent())
if (auto idx = L->getCanonicalInductionVariable())
if (auto PH = L->getLoopPreheader()) {
bool legal = idx != PN;
auto ph_idx = PN->getBasicBlockIndex(PH);
assert(ph_idx >= 0);
for (size_t i = 0; i < PN->getNumIncomingValues(); i++) {
if ((int)i == ph_idx)
continue;
auto v = PN->getIncomingValue(i);
if (v != PN->getIncomingValue(1 - ph_idx)) {
legal = false;
break;
}
// The given var must dominate the loop
if (isa<Constant>(v))
continue;
if (isa<Argument>(v))
continue;
// exception for the induction itself, which we handle specially
if (v == idx)
continue;
auto I = cast<Instruction>(v);
if (!DT.dominates(I, PN)) {
legal = false;
break;
}
}
if (legal) {
auto val = PN->getIncomingValue(1 - ph_idx);
push(val);
if (val == idx) {
val = pushcse(
B.CreateSub(idx, ConstantInt::get(idx->getType(), 1)));
}

auto val2 = PN->getIncomingValue(ph_idx);
push(val2);
auto val2 = PN->getIncomingValue(ph_idx);
push(val2);

auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}
auto c0 = ConstantInt::get(idx->getType(), 0);
// if (val2 == c0 && PN->getIncomingValue(1 - ph_idx) == idx) {
// val = B.CreateBinaryIntrinsic(Intrinsic::umax, c0, val);
//} else {
auto eq = pushcse(B.CreateICmpEQ(idx, c0));
val = pushcse(
B.CreateSelect(eq, val2, val, "phisel." + cur->getName()));
//}

replaceAndErase(cur, val);
return "PhiLoop0Sel";
replaceAndErase(cur, val);
return "PhiLoop0Sel";
}
}
}
// phi (sitofp a), (sitofp b) -> sitofp (phi a, b)
{
SmallVector<Value *, 1> negOps;
Expand Down
49 changes: 42 additions & 7 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9158,14 +9158,15 @@ bool GradientUtils::needsCacheWholeAllocation(
return false;
if (!found->second)
return true;
SmallVector<std::pair<const Instruction *, size_t>, 1> todo;
// User, operand of input, whehter the input is the original allocation
SmallVector<std::tuple<const Instruction *, size_t, bool>, 1> todo;
for (auto &use : origInst->uses())
todo.push_back(
std::make_pair(cast<Instruction>(use.getUser()), use.getOperandNo()));
SmallSet<std::pair<const Instruction *, size_t>, 1> seen;
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(), true));
SmallSet<std::tuple<const Instruction *, size_t, bool>, 1> seen;
while (todo.size()) {
auto pair = todo.back();
auto [cur, idx] = pair;
auto [cur, idx, orig] = pair;
todo.pop_back();
if (seen.count(pair))
continue;
Expand All @@ -9184,6 +9185,8 @@ bool GradientUtils::needsCacheWholeAllocation(
II->getIntrinsicID() == Intrinsic::masked_load)
continue;

bool returnedSameValue = false;

if (auto CI = dyn_cast<CallInst>(cur)) {
#if LLVM_VERSION_MAJOR >= 14
if (idx < CI->arg_size())
Expand All @@ -9193,6 +9196,36 @@ bool GradientUtils::needsCacheWholeAllocation(
{
if (isNoCapture(CI, idx))
continue;

if (auto F = CI->getCalledFunction())
if (F->getCallingConv() == CI->getCallingConv()) {
bool onlyReturnUses = true;
bool hasReturnUse = true;

for (auto u : F->getArg(idx)->users()) {
if (isa<ReturnInst>(u)) {
hasReturnUse = true;
continue;
}
onlyReturnUses = false;
continue;
}
// The arg itself has no use in the function
if (onlyReturnUses && !hasReturnUse)
continue;

// If this is the original allocation, we return it guaranteed, and
// cache the return, that's still fine
if (onlyReturnUses && orig) {
found = knownRecomputeHeuristic.find(cur);
if (found == knownRecomputeHeuristic.end())
continue;

if (!found->second)
continue;
returnedSameValue = true;
}
}
}
}

Expand All @@ -9202,6 +9235,7 @@ bool GradientUtils::needsCacheWholeAllocation(

// If caching this user, it cannot be a gep/cast of original
if (!found->second) {
llvm::errs() << " mod: " << *oldFunc->getParent() << "\n";
llvm::errs() << " oldFunc: " << *oldFunc << "\n";
for (auto &pair : knownRecomputeHeuristic)
llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n";
Expand All @@ -9211,8 +9245,9 @@ bool GradientUtils::needsCacheWholeAllocation(
} else {
// if not caching this user, it is legal to recompute, consider its users
for (auto &use : cur->uses()) {
todo.push_back(std::make_pair(cast<Instruction>(use.getUser()),
use.getOperandNo()));
todo.push_back(std::make_tuple(cast<Instruction>(use.getUser()),
use.getOperandNo(),
returnedSameValue && orig));
}
}
}
Expand Down

0 comments on commit 59ce655

Please sign in to comment.