Skip to content

Commit

Permalink
Fix autodiff ordering with inlining
Browse files (browse the repository at this point in the history)
  • Loading branch information
wsmoses committed Sep 20, 2023
1 parent 8a4532d commit 10a37c8
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 69 deletions.
144 changes: 76 additions & 68 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,8 @@ class EnzymeBase {
Type *retElemType, SmallVectorImpl<Value *> &args,
const std::map<int, Type *> &byVal,
const std::vector<DIFFE_TYPE> &constants, Function *fn,
DerivativeMode mode, Options &options, bool sizeOnly) {
DerivativeMode mode, Options &options, bool sizeOnly,
SmallVectorImpl<CallInst *> &calls) {
auto &differet = options.differet;
auto &tape = options.tape;
auto &width = options.width;
Expand Down Expand Up @@ -1702,63 +1703,13 @@ class EnzymeBase {
}

ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode);

if (Logic.PostOpt) {
auto Params = llvm::getInlineParams();

llvm::SetVector<CallInst *> Q;
Q.insert(diffretc);
while (Q.size()) {
auto cur = *Q.begin();
Function *outerFunc = cur->getParent()->getParent();
llvm::OptimizationRemarkEmitter ORE(outerFunc);
Q.erase(Q.begin());
if (auto F = cur->getCalledFunction()) {
if (!F->empty()) {
// Garbage collect AC's created
SmallVector<AssumptionCache *, 2> ACAlloc;
auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
auto AC = new AssumptionCache(F);
ACAlloc.push_back(AC);
return *AC;
};
auto GetTLI =
[&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
};

auto GetInlineCost = [&](CallBase &CB) {
TargetTransformInfo TTI(F->getParent()->getDataLayout());
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
return cst;
};
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
InlineFunctionInfo IFI;
InlineResult IR = InlineFunction(*cur, IFI);
if (IR.isSuccess()) {
LowerSparsification(outerFunc, /*replaceAll*/ false);
for (auto U : outerFunc->users()) {
if (auto CI = dyn_cast<CallInst>(U)) {
if (CI->getCalledFunction() == outerFunc) {
Q.insert(CI);
}
}
}
}
}
for (auto AC : ACAlloc) {
delete AC;
}
}
}
}
}
return true;
calls.push_back(diffretc);
return diffret;
}

/// Return whether successful
bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode,
bool sizeOnly) {
bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly,
SmallVectorImpl<CallInst *> &calls) {

// determine function to differentiate
Function *fn = parseFunctionParameter(CI);
Expand Down Expand Up @@ -1796,16 +1747,17 @@ class EnzymeBase {

#if LLVM_VERSION_MAJOR >= 16
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.value(),
sizeOnly);
byVal, constants, fn, mode, options.value(), sizeOnly,
calls);
#else
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.getValue(),
sizeOnly);
sizeOnly, calls);
#endif
}

bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
bool HandleProbProg(CallInst *CI, ProbProgMode mode,
SmallVectorImpl<CallInst *> &calls) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
Expand Down Expand Up @@ -1928,13 +1880,15 @@ class EnzymeBase {
}

#if LLVM_VERSION_MAJOR >= 16
bool status = HandleAutoDiff(
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false);
bool status =
HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal,
constants, newFunc, DerivativeMode::ReverseModeCombined,
opt.value(), false, calls);
#else
bool status = HandleAutoDiff(
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false);
bool status =
HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal,
constants, newFunc, DerivativeMode::ReverseModeCombined,
opt.getValue(), false, calls);
#endif

delete interface;
Expand Down Expand Up @@ -2447,17 +2401,19 @@ class EnzymeBase {
Changed = true;
}

SmallVector<CallInst *, 1> calls;

// Perform all the size replacements first to create constants
for (auto pair : toSize) {
bool successful = HandleAutoDiffArguments(pair.first, pair.second,
/*sizeOnly*/ true);
/*sizeOnly*/ true, calls);
Changed = true;
if (!successful)
break;
}
for (auto pair : toLower) {
bool successful = HandleAutoDiffArguments(pair.first, pair.second,
/*sizeOnly*/ false);
/*sizeOnly*/ false, calls);
Changed = true;
if (!successful)
break;
Expand Down Expand Up @@ -2495,7 +2451,59 @@ class EnzymeBase {
}

for (auto &&[call, mode] : toProbProg) {
HandleProbProg(call, mode);
HandleProbProg(call, mode, calls);
}

if (Logic.PostOpt) {
auto Params = llvm::getInlineParams();

llvm::SetVector<CallInst *> Q;
for (auto call : calls)
Q.insert(call);
while (Q.size()) {
auto cur = *Q.begin();
Function *outerFunc = cur->getParent()->getParent();
llvm::OptimizationRemarkEmitter ORE(outerFunc);
Q.erase(Q.begin());
if (auto F = cur->getCalledFunction()) {
if (!F->empty()) {
// Garbage collect AC's created
SmallVector<AssumptionCache *, 2> ACAlloc;
auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
auto AC = new AssumptionCache(F);
ACAlloc.push_back(AC);
return *AC;
};
auto GetTLI =
[&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
};

auto GetInlineCost = [&](CallBase &CB) {
TargetTransformInfo TTI(F->getParent()->getDataLayout());
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
return cst;
};
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
InlineFunctionInfo IFI;
InlineResult IR = InlineFunction(*cur, IFI);
if (IR.isSuccess()) {
LowerSparsification(outerFunc, /*replaceAll*/ false);
for (auto U : outerFunc->users()) {
if (auto CI = dyn_cast<CallInst>(U)) {
if (CI->getCalledFunction() == outerFunc) {
Q.insert(CI);
}
}
}
}
}
for (auto AC : ACAlloc) {
delete AC;
}
}
}
}
}

if (Changed && EnzymeAttributor) {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ReverseMode/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,6 @@ int main() {

gemvTests();

gemmTests();
// gemmTests();

}

0 comments on commit 10a37c8

Please sign in to comment.