Skip to content

Commit

Permalink
Don't clone if empty
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 21, 2023
1 parent e0737ec commit 6152dc8
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 29 deletions.
2 changes: 2 additions & 0 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ DiffeGradientUtils::DiffeGradientUtils(
: GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
constantvalues_, returnvals_, ActiveReturn, constant_values,
origToNew_, mode, width, omp) {
if (oldFunc_->empty())
return;
assert(reverseBlocks.size() == 0);
if (mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeSplit) {
Expand Down
32 changes: 20 additions & 12 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2279,16 +2279,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
constant_args))
->second;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());

if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";
report_fatal_error("function failed verification (r5)");
}
auto newFunc = gutils->newFunc;
delete gutils;
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
Expand Down Expand Up @@ -3938,14 +3939,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(key.todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return newFunc;
}
if (context.req) {

if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
llvm::errs() << *gutils->oldFunc << "\n";
llvm::errs() << *gutils->newFunc << "\n";
report_fatal_error("function failed verification (r4)");
}
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());
auto newFunc = gutils->newFunc;
Expand Down Expand Up @@ -4566,6 +4564,9 @@ Function *EnzymeLogic::CreateForwardDiff(
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
auto newFunc = gutils->newFunc;
delete gutils;
return newFunc;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -4817,6 +4818,7 @@ llvm::Function *EnzymeLogic::CreateBatch(RequestContext context,
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(tobatch),
wrap(context.ip));
return NewF;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -5112,6 +5114,10 @@ EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace,
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(totrace),
wrap(context.ip));
auto newFunc = tutils->newFunc;
delete tracer;
delete tutils;
return newFunc;
}
if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
Expand Down Expand Up @@ -5173,6 +5179,7 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context,
CustomErrorHandler(ss.str().c_str(), wrap(context.req),
ErrorType::NoDerivative, nullptr, wrap(todiff),
wrap(context.ip));
return todiff;
}

if (context.req) {
Expand Down Expand Up @@ -5296,6 +5303,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) {
CustomErrorHandler(ss.str().c_str(), wrap(context.req),
ErrorType::NoDerivative, nullptr, wrap(F),
wrap(context.ip));
return F;
}
if (context.req) {
EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req,
Expand Down
26 changes: 15 additions & 11 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,16 +1325,18 @@ Function *PreProcessCache::preprocessForClone(Function *F,

SmallVector<ReturnInst *, 4> Returns;

if (!F->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(
NewF, F, VMap,
/*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns,
"", nullptr);
CloneFunctionInto(
NewF, F, VMap,
/*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
#else
CloneFunctionInto(NewF, F, VMap,
/*ModuleLevelChanges*/ F->getSubprogram() != nullptr,
Returns, "", nullptr);
CloneFunctionInto(NewF, F, VMap,
/*ModuleLevelChanges*/ F->getSubprogram() != nullptr,
Returns, "", nullptr);
#endif
}
CloneOrigin[NewF] = F;
NewF->setAttributes(F->getAttributes());
if (EnzymeNoAlias)
Expand Down Expand Up @@ -2113,13 +2115,15 @@ Function *PreProcessCache::CloneFunctionWithReturns(
VMap[&I] = &*DestI++; // Add mapping to VMap
}
SmallVector<ReturnInst *, 4> Returns;
if (!F->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
Returns, "", nullptr);
#else
CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "",
nullptr);
CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "",
nullptr);
#endif
}
if (NewF->empty()) {
auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF);
IRBuilder<> B(entry);
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ GradientUtils::GradientUtils(
: Logic.PPC.getAAResultsFromFunction(oldFunc_)),
TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_),
overwritten_args_map_ptr(nullptr) {
if (oldFunc_->empty())
return;
if (oldFunc_->getSubprogram()) {
assert(originalToNewFn_.hasMD());
}
Expand Down
12 changes: 7 additions & 5 deletions enzyme/Enzyme/TraceUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@ TraceUtils::FromClone(ProbProgMode mode,
}

SmallVector<ReturnInst *, 4> Returns;
if (!oldFunc->empty()) {
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
CloneFunctionChangeType::LocalChangesOnly, Returns, "",
nullptr);
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
CloneFunctionChangeType::LocalChangesOnly, Returns, "",
nullptr);
#else
CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "",
nullptr);
CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "",
nullptr);
#endif
}
if (newFunc->empty()) {
auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc);
IRBuilder<> B(entry);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ReverseMode/err_empty.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern double __enzyme_autodiff(void*, double);
double unknown(double in);

double g(double in) {
return unknown(unknown(in)); // expected-error {{Enzyme: No reverse pass found for unknown}} expected-error {{Enzyme: No augmented forward pass found for unknown}} expected-error {{Enzyme: No reverse pass found for unknown at context}}
return unknown(unknown(in)); // expected-error {{Enzyme: No reverse pass found for unknown}} expected-error {{Enzyme: No augmented forward pass found for unknown}} expected-error {{Enzyme: No reverse pass found for unknown}}
}

double square(double x) {
Expand Down

0 comments on commit 6152dc8

Please sign in to comment.