Skip to content

Commit

Permalink
[WIP] Auto truncation
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2023
1 parent 7284903 commit 47ca979
Show file tree
Hide file tree
Showing 7 changed files with 579 additions and 1 deletion.
37 changes: 37 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

//===- Enzyme.cpp - Automatic Differentiation Transformation Pass -------===//
//
// Enzyme Project
Expand Down Expand Up @@ -1314,6 +1315,31 @@ class EnzymeBase {
return type_args;
}

bool HandleTruncate(CallInst *CI) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate", *CI,
" - expected 3");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
RequestContext context(CI, &Builder);
llvm::Value* res = Logic.CreateTruncate(context, F, (unsigned)Cfrom->getValue().getZExtValue(), (unsigned)Cto->getValue().getZExtValue() );
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
CI->replaceAllUsesWith(res);
CI->eraseFromParent();
return true;
}

bool HandleBatch(CallInst *CI) {
unsigned width = 1;
unsigned truei = 0;
Expand Down Expand Up @@ -2028,6 +2054,7 @@ class EnzymeBase {
Fn->getName().contains("__enzyme_augmentfwd") ||
Fn->getName().contains("__enzyme_augmentsize") ||
Fn->getName().contains("__enzyme_reverse") ||
Fn->getName().contains("__enzyme_truncate") ||
Fn->getName().contains("__enzyme_batch") ||
Fn->getName().contains("__enzyme_trace") ||
Fn->getName().contains("__enzyme_condition")))
Expand Down Expand Up @@ -2060,6 +2087,7 @@ class EnzymeBase {
MapVector<CallInst *, DerivativeMode> toVirtual;
MapVector<CallInst *, DerivativeMode> toSize;
SmallVector<CallInst *, 4> toBatch;
SmallVector<CallInst *, 4> toTruncate;
MapVector<CallInst *, ProbProgMode> toProbProg;
SetVector<CallInst *> InactiveCalls;
SetVector<CallInst *> IterCalls;
Expand Down Expand Up @@ -2369,6 +2397,7 @@ class EnzymeBase {
bool virtualCall = false;
bool sizeOnly = false;
bool batch = false;
bool truncate = false;
bool probProg = false;
DerivativeMode derivativeMode;
ProbProgMode probProgMode;
Expand Down Expand Up @@ -2398,6 +2427,9 @@ class EnzymeBase {
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_truncate")) {
enableEnzyme = true;
truncate = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
Expand Down Expand Up @@ -2455,6 +2487,8 @@ class EnzymeBase {
toSize[CI] = derivativeMode;
else if (batch)
toBatch.push_back(CI);
else if (truncate)
toTruncate.push_back(CI);
else if (probProg) {
toProbProg[CI] = probProgMode;
} else
Expand Down Expand Up @@ -2548,6 +2582,9 @@ class EnzymeBase {
for (auto call : toBatch) {
HandleBatch(call);
}
for (auto call : toTruncate) {
HandleTruncate(call);
}

for (auto &&[call, mode] : toProbProg) {
HandleProbProg(call, mode, calls);
Expand Down
Loading

0 comments on commit 47ca979

Please sign in to comment.