From 4d5344841e318a194117f35da90bb4d2156c3cf1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 01:13:41 -0400 Subject: [PATCH] Don't clone if empty --- enzyme/Enzyme/EnzymeLogic.cpp | 32 ++++++++++++++++++++------------ enzyme/Enzyme/FunctionUtils.cpp | 26 +++++++++++++++----------- enzyme/Enzyme/TraceUtils.cpp | 12 +++++++----- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 5e2e8b89b816..420939ec2bc8 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2279,16 +2279,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, ss.str()); - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (r5)"); - } auto newFunc = gutils->newFunc; delete gutils; return insert_or_assign( @@ -3938,14 +3939,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(key.todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; } if (context.req) { - - if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { - llvm::errs() << *gutils->oldFunc << "\n"; - llvm::errs() << *gutils->newFunc << "\n"; - report_fatal_error("function failed verification (r4)"); - } EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, ss.str()); auto newFunc = gutils->newFunc; @@ -4566,6 +4564,9 @@ Function *EnzymeLogic::CreateForwardDiff( CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -4817,6 +4818,7 @@ llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(tobatch), wrap(context.ip)); + return NewF; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -5112,6 +5114,10 @@ EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace, CustomErrorHandler(ss.str().c_str(), wrap(toshow), ErrorType::NoDerivative, nullptr, wrap(totrace), wrap(context.ip)); + auto newFunc = tutils->newFunc; + delete tracer; + delete tutils; + return newFunc; } if (context.req) { EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, @@ -5173,6 +5179,7 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); + return todiff; } if (context.req) { @@ -5296,6 +5303,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(F), wrap(context.ip)); + return F; } if (context.req) { EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 51ece9d7b70d..6189ab1870bc 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1325,16 +1325,18 @@ Function *PreProcessCache::preprocessForClone(Function *F, SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto( - NewF, F, VMap, - /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns, - "", nullptr); + CloneFunctionInto( + NewF, F, VMap, + /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, - /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, + /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, + Returns, "", nullptr); #endif + } CloneOrigin[NewF] = F; NewF->setAttributes(F->getAttributes()); if (EnzymeNoAlias) @@ -2113,13 +2115,15 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMap[&I] = &*DestI++; // Add mapping to VMap } SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", - nullptr); + CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", + nullptr); #endif + } if (NewF->empty()) { auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF); IRBuilder<> B(entry); diff --git a/enzyme/Enzyme/TraceUtils.cpp b/enzyme/Enzyme/TraceUtils.cpp index f31cb24c025e..5e7626c4a247 100644 --- a/enzyme/Enzyme/TraceUtils.cpp +++ b/enzyme/Enzyme/TraceUtils.cpp @@ -114,14 +114,16 @@ TraceUtils::FromClone(ProbProgMode mode, } SmallVector Returns; + if (!oldFunc->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, + CloneFunctionChangeType::LocalChangesOnly, Returns, "", + nullptr); #else - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", + nullptr); #endif + } if (newFunc->empty()) { auto entry = BasicBlock::Create(newFunc->getContext(), "entry", newFunc); IRBuilder<> B(entry);