Skip to content

Commit

Permalink
Don't clone if empty
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 21, 2023
1 parent e0737ec commit 4d53448
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
32 changes: 20 additions & 12 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2279,16 +2279,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
->second;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());

if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";
report_fatal_error("function failed verification (r5)");
}
auto newFunc = gutils->newFunc;
delete gutils;
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
Expand Down Expand Up @@ -3938,14 +3939,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(key.todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return newFunc;
}
if (context.req) {

if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";
report_fatal_error("function failed verification (r4)");
}
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());
auto newFunc = gutils->newFunc;
Expand Down Expand Up @@ -4566,6 +4564,9 @@ Function *EnzymeLogic::CreateForwardDiff(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return newFunc;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -4817,6 +4818,7 @@ llvm::Function *EnzymeLogic::CreateBatch(RequestContext context,
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(tobatch),
wrap(context.ip));
return NewF;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -5112,6 +5114,10 @@ EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace,
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(totrace),
wrap(context.ip));
auto newFunc = tutils->newFunc;
delete tracer;
delete tutils;
return newFunc;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -5173,6 +5179,7 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context,
CustomErrorHandler(ss.str().c_str(), wrap(context.req),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
return todiff;
}

if (context.req) {
Expand Down Expand Up @@ -5296,6 +5303,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) {
CustomErrorHandler(ss.str().c_str(), wrap(context.req),
ErrorType::NoDerivative, nullptr, wrap(F),
wrap(context.ip));
return F;
}
if (context.req) {
EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req,
Expand Down
26 changes: 15 additions & 11 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,16 +1325,18 @@ Function *PreProcessCache::preprocessForClone(Function *F,

SmallVector<ReturnInst *, 4> Returns;

if (!F->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(
NewF, F, VMap,
/*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns,
"", nullptr);
CloneFunctionInto(
NewF, F, VMap,
/*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
#else
CloneFunctionInto(NewF, F, VMap,
/*ModuleLevelChanges*/ F->getSubprogram() != nullptr,
Returns, "", nullptr);
CloneFunctionInto(NewF, F, VMap,
/*ModuleLevelChanges*/ F->getSubprogram() != nullptr,
Returns, "", nullptr);
#endif
}
CloneOrigin[NewF] = F;
NewF->setAttributes(F->getAttributes());
if (EnzymeNoAlias)
Expand Down Expand Up @@ -2113,13 +2115,15 @@ Function *PreProcessCache::CloneFunctionWithReturns(
VMap[&I] = &*DestI++; // Add mapping to VMap
}
SmallVector<ReturnInst *, 4> Returns;
if (!F->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
#else
CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "",
nullptr);
CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "",
nullptr);
#endif
}
if (NewF->empty()) {
auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF);
IRBuilder<> B(entry);
Expand Down
12 changes: 7 additions & 5 deletions enzyme/Enzyme/TraceUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@ TraceUtils::FromClone(ProbProgMode mode,
}

SmallVector<ReturnInst *, 4> Returns;
if (!oldFunc->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
CloneFunctionChangeType::LocalChangesOnly, Returns, "",
nullptr);
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
CloneFunctionChangeType::LocalChangesOnly, Returns, "",
nullptr);
#else
CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "",
nullptr);
CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "",
nullptr);
#endif
}
if (newFunc->empty()) {
auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc);
IRBuilder<> B(entry);
Expand Down

0 comments on commit 4d53448

Please sign in to comment.