Skip to content

Commit

Permalink
Merge branch 'main' into wsmoses-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 16, 2024
2 parents e9b45b7 + 9b979e3 commit 10c82a7
Show file tree
Hide file tree
Showing 29 changed files with 3,723 additions and 870 deletions.
14 changes: 14 additions & 0 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,20 @@ void RemoveRedundantIVs(
// and must thus be expanded after all phi's
Value *NewIV =
Exp.expandCodeFor(S, Tmp->getType(), Header->getFirstNonPHI());

// Explicity preserve wrap behavior from original iv. This is necessary
// until this PR in llvm is merged:
// https://github.com/llvm/llvm-project/pull/78199
if (auto addrec = dyn_cast<SCEVAddRecExpr>(S)) {
if (addrec->getLoop()->getHeader() == Header) {
if (auto add_or_mul = dyn_cast<BinaryOperator>(NewIV)) {
if (addrec->getNoWrapFlags(llvm::SCEV::FlagNUW))
add_or_mul->setHasNoUnsignedWrap(true);
if (addrec->getNoWrapFlags(llvm::SCEV::FlagNSW))
add_or_mul->setHasNoSignedWrap(true);
}
}
}
replacer(Tmp, NewIV);
eraser(Tmp);
}
Expand Down
38 changes: 38 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,33 @@ class EnzymeBase {
return type_args;
}

bool HandleTruncate(CallInst *CI) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate", *CI,
" - expected 3");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncate(
context, F, (unsigned)Cfrom->getValue().getZExtValue(),
(unsigned)Cto->getValue().getZExtValue());
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
CI->replaceAllUsesWith(res);
CI->eraseFromParent();
return true;
}

bool HandleBatch(CallInst *CI) {
unsigned width = 1;
unsigned truei = 0;
Expand Down Expand Up @@ -2028,6 +2055,7 @@ class EnzymeBase {
Fn->getName().contains("__enzyme_augmentfwd") ||
Fn->getName().contains("__enzyme_augmentsize") ||
Fn->getName().contains("__enzyme_reverse") ||
Fn->getName().contains("__enzyme_truncate") ||
Fn->getName().contains("__enzyme_batch") ||
Fn->getName().contains("__enzyme_trace") ||
Fn->getName().contains("__enzyme_condition")))
Expand Down Expand Up @@ -2060,6 +2088,7 @@ class EnzymeBase {
MapVector<CallInst *, DerivativeMode> toVirtual;
MapVector<CallInst *, DerivativeMode> toSize;
SmallVector<CallInst *, 4> toBatch;
SmallVector<CallInst *, 4> toTruncate;
MapVector<CallInst *, ProbProgMode> toProbProg;
SetVector<CallInst *> InactiveCalls;
SetVector<CallInst *> IterCalls;
Expand Down Expand Up @@ -2369,6 +2398,7 @@ class EnzymeBase {
bool virtualCall = false;
bool sizeOnly = false;
bool batch = false;
bool truncate = false;
bool probProg = false;
DerivativeMode derivativeMode;
ProbProgMode probProgMode;
Expand Down Expand Up @@ -2398,6 +2428,9 @@ class EnzymeBase {
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_truncate")) {
enableEnzyme = true;
truncate = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
Expand Down Expand Up @@ -2455,6 +2488,8 @@ class EnzymeBase {
toSize[CI] = derivativeMode;
else if (batch)
toBatch.push_back(CI);
else if (truncate)
toTruncate.push_back(CI);
else if (probProg) {
toProbProg[CI] = probProgMode;
} else
Expand Down Expand Up @@ -2548,6 +2583,9 @@ class EnzymeBase {
for (auto call : toBatch) {
HandleBatch(call);
}
for (auto call : toTruncate) {
HandleTruncate(call);
}

for (auto &&[call, mode] : toProbProg) {
HandleProbProg(call, mode, calls);
Expand Down
Loading

0 comments on commit 10c82a7

Please sign in to comment.