From 6152dc82a9298d710bc670670610da7a60f2c6e1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 01:13:41 -0400 Subject: [PATCH] Don't clone if empty --- enzyme/Enzyme/DiffeGradientUtils.cpp | 2 ++ enzyme/Enzyme/EnzymeLogic.cpp | 32 ++++++++++++------- enzyme/Enzyme/FunctionUtils.cpp | 26 ++++++++------- enzyme/Enzyme/GradientUtils.cpp | 2 ++ enzyme/Enzyme/TraceUtils.cpp | 12 ++++--- .../test/Integration/ReverseMode/err_empty.c | 2 +- 6 files changed, 47 insertions(+), 29 deletions(-) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index cdec0bd7c7f4..c88ce960b5df 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -62,6 +62,8 @@ DiffeGradientUtils::DiffeGradientUtils( : GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_, constantvalues_, returnvals_, ActiveReturn, constant_values, origToNew_, mode, width, omp) { + if (oldFunc_->empty()) + return; assert(reverseBlocks.size() == 0); if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit) { diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 5e2e8b89b816..420939ec2bc8 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2279,16 +2279,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, ss.str()); - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (r5)"); - } auto newFunc = gutils->newFunc; delete gutils; return insert_or_assign( @@ -3938,14 +3939,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(key.todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; } if (context.req) { - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (r4)"); - } EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, ss.str()); auto newFunc = gutils->newFunc; @@ -4566,6 +4564,9 @@ Function *EnzymeLogic::CreateForwardDiff( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -4817,6 +4818,7 @@ llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(tobatch), wrap(context.ip)); + return NewF; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -5112,6 +5114,10 @@ EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace, CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(totrace), wrap(context.ip)); + auto newFunc = tutils->newFunc; + delete tracer; + delete tutils; + return newFunc; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -5173,6 +5179,7 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + return todiff; } if (context.req) { @@ -5296,6 +5303,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(F), wrap(context.ip)); + return F; } if (context.req) { EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 51ece9d7b70d..6189ab1870bc 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1325,16 +1325,18 @@ Function *PreProcessCache::preprocessForClone(Function *F, SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto( - NewF, F, VMap, - /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns, - "", nullptr); + CloneFunctionInto( + NewF, F, VMap, + /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, - /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, + /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, + Returns, "", nullptr); #endif + } CloneOrigin[NewF] = F; NewF->setAttributes(F->getAttributes()); if (EnzymeNoAlias) @@ -2113,13 +2115,15 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMap[&I] = &*DestI++; // Add mapping to VMap } SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", - nullptr); + CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", + nullptr); #endif + } if (NewF->empty()) { auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF); IRBuilder<> B(entry); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 683e98d82fc2..7c00ac376bad 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -195,6 +195,8 @@ GradientUtils::GradientUtils( : Logic.PPC.getAAResultsFromFunction(oldFunc_)), TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), overwritten_args_map_ptr(nullptr) { + if (oldFunc_->empty()) + return; if (oldFunc_->getSubprogram()) { assert(originalToNewFn_.hasMD()); } diff --git a/enzyme/Enzyme/TraceUtils.cpp b/enzyme/Enzyme/TraceUtils.cpp index f31cb24c025e..5e7626c4a247 100644 --- a/enzyme/Enzyme/TraceUtils.cpp +++ b/enzyme/Enzyme/TraceUtils.cpp @@ -114,14 +114,16 @@ TraceUtils::FromClone(ProbProgMode mode, } SmallVector Returns; + if (!oldFunc->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, + CloneFunctionChangeType::LocalChangesOnly, Returns, "", + nullptr); #else - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", + nullptr); #endif + } if (newFunc->empty()) { auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc); IRBuilder<> B(entry); diff --git a/enzyme/test/Integration/ReverseMode/err_empty.c b/enzyme/test/Integration/ReverseMode/err_empty.c index 213121dd8e0a..899b1cb364af 100644 --- a/enzyme/test/Integration/ReverseMode/err_empty.c +++ b/enzyme/test/Integration/ReverseMode/err_empty.c @@ -12,7 +12,7 @@ extern double __enzyme_autodiff(void*, double); double unknown(double in); double g(double in) { - return unknown(unknown(in)); // expected-error {{Enzyme: No reverse pass found for unknown}} expected-error {{Enzyme: No augmented forward pass found for unknown}} expected-error {{Enzyme: No reverse pass found for unknown at context}} + return unknown(unknown(in)); // expected-error {{Enzyme: No reverse pass found for unknown}} expected-error {{Enzyme: No augmented forward pass found for unknown}} expected-error {{Enzyme: No reverse pass found for unknown}} } double square(double x) {