From 693140b43625cfe5dbdd1132500b7dff5d5a7662 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 13 Dec 2023 19:39:20 -0500 Subject: [PATCH 01/13] [WIP] Auto truncation --- enzyme/Enzyme/Enzyme.cpp | 37 ++ enzyme/Enzyme/EnzymeLogic.cpp | 479 +++++++++++++++++++++ enzyme/Enzyme/EnzymeLogic.h | 5 + enzyme/Enzyme/Utils.h | 1 + enzyme/test/Enzyme/CMakeLists.txt | 1 + enzyme/test/Enzyme/Truncate/CMakeLists.txt | 12 + enzyme/test/Enzyme/Truncate/simple.ll | 43 ++ 7 files changed, 578 insertions(+) create mode 100644 enzyme/test/Enzyme/Truncate/CMakeLists.txt create mode 100644 enzyme/test/Enzyme/Truncate/simple.ll diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d09682c0aff7..d4435b483682 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1,3 +1,4 @@ + //===- Enzyme.cpp - Automatic Differentiation Transformation Pass -------===// // // Enzyme Project @@ -1314,6 +1315,31 @@ class EnzymeBase { return type_args; } + bool HandleTruncate(CallInst *CI) { + IRBuilder<> Builder(CI); + Function *F = parseFunctionParameter(CI); + if (!F) + return false; + if (CI->arg_size() != 3) { + EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, + "Had incorrect number of args to __enzyme_truncate", *CI, + " - expected 3"); + return false; + } + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + RequestContext context(CI, &Builder); + llvm::Value* res = Logic.CreateTruncate(context, F, (unsigned)Cfrom->getValue().getZExtValue(), (unsigned)Cto->getValue().getZExtValue() ); + if (!res) + return false; + res = Builder.CreatePointerCast(res, CI->getType()); + CI->replaceAllUsesWith(res); + CI->eraseFromParent(); + return true; + } + bool HandleBatch(CallInst *CI) { unsigned width = 1; unsigned truei = 0; @@ -2028,6 +2054,7 @@ class EnzymeBase { Fn->getName().contains("__enzyme_augmentfwd") || Fn->getName().contains("__enzyme_augmentsize") || 
Fn->getName().contains("__enzyme_reverse") || + Fn->getName().contains("__enzyme_truncate") || Fn->getName().contains("__enzyme_batch") || Fn->getName().contains("__enzyme_trace") || Fn->getName().contains("__enzyme_condition"))) @@ -2060,6 +2087,7 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; + SmallVector toTruncate; MapVector toProbProg; SetVector InactiveCalls; SetVector IterCalls; @@ -2369,6 +2397,7 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; + bool truncate = false; bool probProg = false; DerivativeMode derivativeMode; ProbProgMode probProgMode; @@ -2398,6 +2427,9 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; + } else if (Fn->getName().contains("__enzyme_truncate")) { + enableEnzyme = true; + truncate = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { enableEnzyme = true; probProgMode = ProbProgMode::Likelihood; @@ -2455,6 +2487,8 @@ class EnzymeBase { toSize[CI] = derivativeMode; else if (batch) toBatch.push_back(CI); + else if (truncate) + toTruncate.push_back(CI); else if (probProg) { toProbProg[CI] = probProgMode; } else @@ -2548,6 +2582,9 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } + for (auto call : toTruncate) { + HandleTruncate(call); + } for (auto &&[call, mode] : toProbProg) { HandleProbProg(call, mode, calls); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 57aef11c087f..69ff0860bf17 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4815,6 +4815,485 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } +class TruncateGenerator : public llvm::InstVisitor { +private: +ValueToValueMapTy &originalToNewFn; +unsigned fromwidth; +unsigned towidth; +Function* oldFunc; +Function* newFunc; +AllocaInst* tmpBlock; +EnzymeLogic &Logic; + +public: +TruncateGenerator(ValueToValueMapTy &originalToNewFn, 
unsigned fromwidth, unsigned towidth, Function* oldFunc, Function* newFunc, EnzymeLogic& Logic) : + originalToNewFn(originalToNewFn), fromwidth(fromwidth), towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + IRBuilder <> B(&newFunc->getEntryBlock().front()); + tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth)); + } + + void visitInstruction(llvm::Instruction &inst) { + using namespace llvm; + + // TODO explicitly handle all instructions rather than using the catch all + // below + + switch (inst.getOpcode()) { +//#include "InstructionDerivatives.inc" + default: + break; + } + + todo(inst); + } + + Type* getTypeForWidth(unsigned width) { + switch(width){ + default: + return llvm::Type::getIntNTy(oldFunc->getContext(), width); + case 64: + return llvm::Type::getDoubleTy(oldFunc->getContext()); + case 32: + return llvm::Type::getFloatTy(oldFunc->getContext()); + case 16: + return llvm::Type::getHalfTy(oldFunc->getContext()); + } + } + Value *truncate(IRBuilder<> &B, Value* v) { + Type* nextType = getTypeForWidth(towidth); + B.CreateStore(v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad(nextType, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); + } + + Value *expand(IRBuilder<> &B, Value* v, Type* origT) { + auto c0 = Constant::getNullValue(llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); + B.CreateStore(c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); + B.CreateStore(v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad(origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT))); + } + + void todo(llvm::Instruction &I) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot handle unknown instruction\n" << I; + if (CustomErrorHandler) { + IRBuilder<> Builder2(getNewFromOriginal(&I)); + CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, + this, nullptr, wrap(&Builder2)); 
+ return; + } else { + EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); + return; + } + } + + void visitAllocaInst(llvm::AllocaInst &I) { + return; + } + void visitICmpInst(llvm::ICmpInst &I) { + return; + } + void visitFCmpInst(llvm::FCmpInst &I) { + todo(I); + return; + } + void visitLoadInst(llvm::LoadInst &LI) { + auto alignment = LI.getAlign(); + visitLoadLike(LI, alignment); + } + void visitStoreInst(llvm::StoreInst &SI) { + auto align = SI.getAlign(); + visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align, + SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), + /*mask=*/nullptr); + } + void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { + return; + } + void visitPHINode(llvm::PHINode &phi) { + return; + } + void visitCastInst(llvm::CastInst &phi) { + todo(phi); + return; + } + void visitSelectInst(llvm::SelectInst &SI) { + todo(SI); + return; + } + void visitExtractElementInst(llvm::ExtractElementInst &EEI) { + return; + } + void visitInsertElementInst(llvm::InsertElementInst &EEI) { + return; + } + void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { + return; + } + void visitExtractValueInst(llvm::ExtractValueInst &EEI) { + return; + } + void visitInsertValueInst(llvm::InsertValueInst &EEI) { + return; + } + void visitBinaryOperator(llvm::BinaryOperator &BO) { + + switch(BO.getOpcode()) { + default: break; + case BinaryOperator::Add: + case BinaryOperator::Sub: + case BinaryOperator::Mul: + case BinaryOperator::UDiv: + case BinaryOperator::SDiv: + case BinaryOperator::URem: + case BinaryOperator::SRem: + case BinaryOperator::AShr: + case BinaryOperator::LShr: + case BinaryOperator::Shl: + case BinaryOperator::And: + case BinaryOperator::Or: + case BinaryOperator::Xor: + return; + } + + if (towidth == 32 || towidth == 16 || towidth == 64) { + auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + switch(BO.getOpcode()) { + default: break; + case BinaryOperator::FMul: + { + auto nres = 
cast(B.CreateFMul(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FAdd: + { + auto nres = cast(B.CreateFAdd(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FSub: + { + auto nres = cast(B.CreateFSub(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FDiv: + { + auto nres = cast(B.CreateFDiv(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->eraseFromParent(); + } + return; + case BinaryOperator::FRem: + { + auto nres = cast(B.CreateFRem(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->eraseFromParent(); + } + return; + } + } + todo(BO); + return; + } + void visitMemSetInst(llvm::MemSetInst &MS) { + visitMemSetCommon(MS); + } + void visitMemSetCommon(llvm::CallInst &MS) { + return; + } + void visitMemTransferInst(llvm::MemTransferInst &MTI) { + using namespace llvm; + Value *isVolatile = getNewFromOriginal(MTI.getOperand(3)); + auto srcAlign = MTI.getSourceAlign(); + auto dstAlign = MTI.getDestAlign(); + visitMemTransferCommon(MTI.getIntrinsicID(), 
srcAlign, dstAlign, MTI, + MTI.getOperand(0), MTI.getOperand(1), + getNewFromOriginal(MTI.getOperand(2)), + isVolatile); + } + void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, + llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, + llvm::Value *orig_dst, llvm::Value *orig_src, + llvm::Value *new_size, llvm::Value *isVolatile) { + return; + } + void visitFenceInst(llvm::FenceInst &FI) { + return; + } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { + SmallVector orig_ops(II.getNumOperands()); + for (unsigned i = 0; i < II.getNumOperands(); ++i) { + orig_ops[i] = II.getOperand(i); + } + if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) + return; + todo(II); + return; + } + + void visitReturnInst(llvm::ReturnInst &I) { + return; + } + + void visitBranchInst(llvm::BranchInst &I) { + return; + } + void visitSwitchInst(llvm::SwitchInst &I) { + return; + } + void visitUnreachableInst(llvm::UnreachableInst &I) { + return; + } + void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, + llvm::Value *mask = nullptr, + llvm::Value *orig_maskInit = nullptr) { + return; + } + + void visitCommonStore(llvm::Instruction &I, llvm::Value *orig_ptr, + llvm::Value *orig_val, llvm::MaybeAlign prevalign, + bool isVolatile, llvm::AtomicOrdering ordering, + llvm::SyncScope::ID syncScope, llvm::Value *mask) { + return; + } + + bool + handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, + llvm::SmallVectorImpl &orig_ops) { + using namespace llvm; + + + switch (ID) { + case Intrinsic::nvvm_ldu_global_i: + case Intrinsic::nvvm_ldu_global_p: + case Intrinsic::nvvm_ldu_global_f: + case Intrinsic::nvvm_ldg_global_i: + case Intrinsic::nvvm_ldg_global_p: + case Intrinsic::nvvm_ldg_global_f: { + auto CI = cast(I.getOperand(1)); + visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue())); + return false; + } + default: + break; + } + + if (ID == Intrinsic::masked_store) { + auto align0 = 
cast(I.getOperand(2))->getZExtValue(); + auto align = MaybeAlign(align0); + visitCommonStore(I, /*orig_ptr*/ I.getOperand(1), + /*orig_val*/ I.getOperand(0), align, + /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic, + SyncScope::SingleThread, + /*mask*/ getNewFromOriginal(I.getOperand(3))); + return false; + } + if (ID == Intrinsic::masked_load) { + auto align0 = cast(I.getOperand(1))->getZExtValue(); + auto align = MaybeAlign(align0); + visitLoadLike(I, align, + /*mask*/ getNewFromOriginal(I.getOperand(2)), + /*orig_maskInit*/ I.getOperand(3)); + return false; + } + + auto called = cast(&I)->getCalledFunction(); + (void)called; + switch (ID) { +//#include "IntrinsicDerivatives.inc" + default: + break; + } + + switch (ID) { + case Intrinsic::nvvm_barrier0: + case Intrinsic::nvvm_barrier0_popc: + case Intrinsic::nvvm_barrier0_and: + case Intrinsic::nvvm_barrier0_or: + case Intrinsic::nvvm_membar_cta: + case Intrinsic::nvvm_membar_gl: + case Intrinsic::nvvm_membar_sys: + case Intrinsic::amdgcn_s_barrier: + return false; + default: break; + } + return true; + } + + llvm::Value *getNewFromOriginal(llvm::Value* v) { + auto found = originalToNewFn.find(v); + assert(found != originalToNewFn.end()); + return found->second; + } + + llvm::Instruction *getNewFromOriginal(llvm::Instruction* v) { + return cast(getNewFromOriginal((llvm::Value*)v)); + } + + bool handleKnownCalls(llvm::CallInst &call, llvm::Function *called, + llvm::StringRef funcName, + llvm::CallInst *const newCall) { + return false; + } + + Value* GetShadow(RequestContext &ctx, Value* v) { + if (auto F = dyn_cast(v)) + return Logic.CreateTruncate(ctx, F, fromwidth, towidth); + llvm::errs() << " unknown get truncated func: " << *v << "\n"; + llvm_unreachable("unknown get truncated func"); + return v; + } + // Return + void visitCallInst(llvm::CallInst &call) { + using namespace llvm; + + CallInst *const newCall = cast(getNewFromOriginal(&call)); + IRBuilder<> BuilderZ(newCall); + + if (auto called = 
call.getCalledFunction()) + if (handleKnownCalls(call, called, getFuncNameFromCall(&call), + newCall)) + return; + + RequestContext ctx(&call, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); + newCall->setCalledOperand(val); + return; + } +}; + +llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, llvm::Function *totrunc, + unsigned fromwidth, unsigned towidth){ + if (fromwidth == towidth) return totrunc; + + TruncateCacheKey tup(totrunc, fromwidth, towidth); + if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { + return TruncateCachedFunctions.find(tup)->second; + } + + FunctionType *orig_FTy = totrunc->getFunctionType(); + SmallVector params; + + for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { + params.push_back(orig_FTy->getParamType(i)); + } + + Type *NewTy = totrunc->getReturnType(); + + FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); + Function *NewF = + Function::Create(FTy, totrunc->getLinkage(), + "trunc_" + std::to_string(fromwidth) + "_" + std::to_string(towidth) + totrunc->getName(), totrunc->getParent()); + + NewF->setLinkage(Function::LinkageTypes::InternalLinkage); + + TruncateCachedFunctions[tup] = NewF; + + if (totrunc->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No truncate mode found for " + totrunc->getName() << "\n"; + llvm::Value *toshow = totrunc; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *totrunc << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(totrunc), + wrap(context.ip)); + return NewF; + } + if (context.req) { + EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, + ss.str()); + return NewF; + } + llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; + llvm::errs() << *totrunc << "\n"; + llvm_unreachable("attempting to truncate function 
without definition"); + } + + if (fromwidth < towidth) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Cannot truncate into a large width\n"; + llvm::Value *toshow = totrunc; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *totrunc << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(totrunc), + wrap(context.ip)); + return NewF; + } + if (context.req) { + EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, + ss.str()); + return NewF; + } + llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; + llvm::errs() << *totrunc << "\n"; + llvm_unreachable("attempting to truncate function without definition"); + } + + + ValueToValueMapTy originalToNewFn; + + for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); i != totrunc->arg_end();) { + originalToNewFn[i] = j; + j->setName(i->getName()); + ++j; + ++i; + } + + SmallVector Returns; +#if LLVM_VERSION_MAJOR >= 13 + CloneFunctionInto(NewF, totrunc, originalToNewFn, + CloneFunctionChangeType::LocalChangesOnly, Returns, "", + nullptr); +#else + CloneFunctionInto(NewF, totrunc, originalToNewFn, true, Returns, "", nullptr); +#endif + + NewF->setLinkage(Function::LinkageTypes::InternalLinkage); + + TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, *this); + for (auto &BB : *totrunc) + for (auto &I : BB) + handle.visit(&I); + + if (llvm::verifyFunction(*NewF, &llvm::errs())) { + llvm::errs() << *totrunc << "\n"; + llvm::errs() << *NewF << "\n"; + report_fatal_error("function failed verification (5)"); + } + + return NewF; +} + llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, Function *tobatch, unsigned width, ArrayRef arg_types, diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index c7f7c4bae86e..6a585a60f12e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -510,6 
+510,11 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); + using TruncateCacheKey = std::tuple; + std::map TruncateCachedFunctions; + llvm::Function *CreateTruncate(RequestContext context, llvm::Function *tobatch, + unsigned fromwidth, unsigned towidth); + /// Create a traced version of a function /// \p context the instruction which requested this trace (or null). /// \p totrace is the function to trace diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index c289c709deed..b7b7b5f31f29 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -83,6 +83,7 @@ enum class ErrorType { MixedActivityError = 7, IllegalReplaceFicticiousPHIs = 8, GetIndexError = 9, + NoTruncate = 10, }; extern "C" { diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 0187644409f2..d88af6ddd95e 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Sparse) +add_subdirectory(Truncate) add_subdirectory(ReverseMode) add_subdirectory(ReverseModeVector) add_subdirectory(ForwardMode) diff --git a/enzyme/test/Enzyme/Truncate/CMakeLists.txt b/enzyme/test/Enzyme/Truncate/CMakeLists.txt new file mode 100644 index 000000000000..79e649ab8e4b --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-trunc "Running enzyme truncation tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-trunc PROPERTIES FOLDER "Tests") + +# add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} +# DEPENDS ${ENZYME_TEST_DEPS} +# ) diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll new file mode 100644 index 000000000000..b0cea56f2292 --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -0,0 +1,43 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; 
RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define void @f(double* %x) { + %y = load double, double* %x + %m = fmul double %y, %y + store double %m, double* %x + ret void +} + +declare void (double*)* @__enzyme_truncate(...) + +define void @tester(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} + +; CHECK: define void @tester(double* %data) +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @trunc_64_32f(double* %data) +; CHECK-NEXT: ret void +; CHECK-NEXT: } + +; CHECK: define internal void @trunc_64_32f(double* %x) +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: %y = load double, double* %x, align 8 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to float* +; CHECK-NEXT: %3 = load float, float* %2, align 4 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %4 = bitcast double* %1 to float* +; CHECK-NEXT: %5 = load float, float* %4, align 4 +; CHECK-NEXT: %m = fmul float %5, %3 +; CHECK-NEXT: %6 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %6, align 4 +; CHECK-NEXT: %7 = bitcast double* %1 to float* +; CHECK-NEXT: store float %m, float* %7, align 4 +; CHECK-NEXT: %8 = load double, double* %1, align 8 +; CHECK-NEXT: store double %8, double* %x, align 8 +; CHECK-NEXT: ret void +; CHECK-NEXT: } \ No newline at end of file From 337e84717186a47d28e390f1041dedef2e016dc6 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 21 Dec 2023 08:19:02 +0900 Subject: [PATCH 02/13] Add FCmp and cast --- enzyme/Enzyme/EnzymeLogic.cpp | 342 ++++++++++++++++++---------------- 1 file changed, 181 insertions(+), 161 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 69ff0860bf17..bb7d0981f30a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4817,18 +4817,21 @@ Function 
*EnzymeLogic::CreateForwardDiff( class TruncateGenerator : public llvm::InstVisitor { private: -ValueToValueMapTy &originalToNewFn; -unsigned fromwidth; -unsigned towidth; -Function* oldFunc; -Function* newFunc; -AllocaInst* tmpBlock; -EnzymeLogic &Logic; + ValueToValueMapTy &originalToNewFn; + unsigned fromwidth; + unsigned towidth; + Function *oldFunc; + Function *newFunc; + AllocaInst *tmpBlock; + EnzymeLogic &Logic; public: -TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsigned towidth, Function* oldFunc, Function* newFunc, EnzymeLogic& Logic) : - originalToNewFn(originalToNewFn), fromwidth(fromwidth), towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { - IRBuilder <> B(&newFunc->getEntryBlock().front()); + TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, + unsigned towidth, Function *oldFunc, Function *newFunc, + EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), fromwidth(fromwidth), + towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + IRBuilder<> B(&newFunc->getEntryBlock().front()); tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth)); } @@ -4839,7 +4842,7 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign // below switch (inst.getOpcode()) { -//#include "InstructionDerivatives.inc" + //#include "InstructionDerivatives.inc" default: break; } @@ -4847,29 +4850,41 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign todo(inst); } - Type* getTypeForWidth(unsigned width) { - switch(width){ - default: - return llvm::Type::getIntNTy(oldFunc->getContext(), width); - case 64: - return llvm::Type::getDoubleTy(oldFunc->getContext()); - case 32: - return llvm::Type::getFloatTy(oldFunc->getContext()); - case 16: - return llvm::Type::getHalfTy(oldFunc->getContext()); + Type *getTypeForWidth(unsigned width) { + switch (width) { + default: + return llvm::Type::getIntNTy(oldFunc->getContext(), width); + case 64: 
+ return llvm::Type::getDoubleTy(oldFunc->getContext()); + case 32: + return llvm::Type::getFloatTy(oldFunc->getContext()); + case 16: + return llvm::Type::getHalfTy(oldFunc->getContext()); } } - Value *truncate(IRBuilder<> &B, Value* v) { - Type* nextType = getTypeForWidth(towidth); - B.CreateStore(v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); - return B.CreateLoad(nextType, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); + + Type *getFromType() { return getTypeForWidth(fromwidth); } + + Type *getToType() { return getTypeForWidth(towidth); } + + Value *truncate(IRBuilder<> &B, Value *v) { + Type *nextType = getTypeForWidth(towidth); + B.CreateStore( + v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + nextType, + B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); } - Value *expand(IRBuilder<> &B, Value* v, Type* origT) { - auto c0 = Constant::getNullValue(llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); - B.CreateStore(c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); - B.CreateStore(v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); - return B.CreateLoad(origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT))); + Value *expand(IRBuilder<> &B, Value *v, Type *origT) { + auto c0 = Constant::getNullValue( + llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); + B.CreateStore(c0, B.CreatePointerCast( + tmpBlock, PointerType::getUnqual(c0->getType()))); + B.CreateStore( + v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType()))); + return B.CreateLoad( + origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT))); } void todo(llvm::Instruction &I) { @@ -4887,14 +4902,18 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign } } - void visitAllocaInst(llvm::AllocaInst &I) { - return; - } - void visitICmpInst(llvm::ICmpInst &I) { - return; - 
} - void visitFCmpInst(llvm::FCmpInst &I) { - todo(I); + void visitAllocaInst(llvm::AllocaInst &I) { return; } + void visitICmpInst(llvm::ICmpInst &I) { return; } + void visitFCmpInst(llvm::FCmpInst &CI) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + auto nres = cast(B.CreateFCmp( + CI.getPredicate(), truncate(B, getNewFromOriginal(CI.getOperand(0))), + truncate(B, getNewFromOriginal(CI.getOperand(1))))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, CI.getType())); + newI->eraseFromParent(); return; } void visitLoadInst(llvm::LoadInst &LI) { @@ -4907,116 +4926,123 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), /*mask=*/nullptr); } - void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { - return; - } - void visitPHINode(llvm::PHINode &phi) { - return; - } - void visitCastInst(llvm::CastInst &phi) { - todo(phi); + void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } + void visitPHINode(llvm::PHINode &phi) { return; } + void visitCastInst(llvm::CastInst &CI) { + Value *newCI = nullptr; + auto newI = getNewFromOriginal(&CI); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } return; } void visitSelectInst(llvm::SelectInst &SI) { todo(SI); return; } - void visitExtractElementInst(llvm::ExtractElementInst &EEI) { - return; - } - void visitInsertElementInst(llvm::InsertElementInst &EEI) { - return; - } - void 
visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { - return; - } - void visitExtractValueInst(llvm::ExtractValueInst &EEI) { - return; - } - void visitInsertValueInst(llvm::InsertValueInst &EEI) { - return; - } + void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } + void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } + void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } + void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } + void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } void visitBinaryOperator(llvm::BinaryOperator &BO) { - switch(BO.getOpcode()) { - default: break; - case BinaryOperator::Add: - case BinaryOperator::Sub: - case BinaryOperator::Mul: - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - case BinaryOperator::Shl: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - return; + switch (BO.getOpcode()) { + default: + break; + case BinaryOperator::Add: + case BinaryOperator::Sub: + case BinaryOperator::Mul: + case BinaryOperator::UDiv: + case BinaryOperator::SDiv: + case BinaryOperator::URem: + case BinaryOperator::SRem: + case BinaryOperator::AShr: + case BinaryOperator::LShr: + case BinaryOperator::Shl: + case BinaryOperator::And: + case BinaryOperator::Or: + case BinaryOperator::Xor: + return; } if (towidth == 32 || towidth == 16 || towidth == 64) { auto newI = getNewFromOriginal(&BO); IRBuilder<> B(newI); - switch(BO.getOpcode()) { - default: break; - case BinaryOperator::FMul: - { - auto nres = cast(B.CreateFMul(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + switch (BO.getOpcode()) { + default: + break; + case BinaryOperator::FMul: { + auto nres = cast( + B.CreateFMul(truncate(B, getNewFromOriginal(BO.getOperand(0))), + truncate(B, 
getNewFromOriginal(BO.getOperand(1))))); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); newI->eraseFromParent(); - } + } return; - case BinaryOperator::FAdd: - { - auto nres = cast(B.CreateFAdd(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + case BinaryOperator::FAdd: { + auto nres = cast( + B.CreateFAdd(truncate(B, getNewFromOriginal(BO.getOperand(0))), + truncate(B, getNewFromOriginal(BO.getOperand(1))))); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); newI->eraseFromParent(); - } + } return; - case BinaryOperator::FSub: - { - auto nres = cast(B.CreateFSub(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + case BinaryOperator::FSub: { + auto nres = cast( + B.CreateFSub(truncate(B, getNewFromOriginal(BO.getOperand(0))), + truncate(B, getNewFromOriginal(BO.getOperand(1))))); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); newI->eraseFromParent(); - } + } return; - case BinaryOperator::FDiv: - { - auto nres = cast(B.CreateFDiv(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + case BinaryOperator::FDiv: { + auto nres = cast( + B.CreateFDiv(truncate(B, getNewFromOriginal(BO.getOperand(0))), + truncate(B, getNewFromOriginal(BO.getOperand(1))))); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); newI->eraseFromParent(); - } + } return; - case BinaryOperator::FRem: - { - auto nres = cast(B.CreateFRem(truncate(B, getNewFromOriginal(BO.getOperand(0))), truncate(B, getNewFromOriginal(BO.getOperand(1))))); + case BinaryOperator::FRem: { + auto nres = cast( + B.CreateFRem(truncate(B, getNewFromOriginal(BO.getOperand(0))), + truncate(B, getNewFromOriginal(BO.getOperand(1))))); 
nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); newI->eraseFromParent(); - } + } return; - } + } } todo(BO); return; } - void visitMemSetInst(llvm::MemSetInst &MS) { - visitMemSetCommon(MS); - } - void visitMemSetCommon(llvm::CallInst &MS) { - return; - } + void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } + void visitMemSetCommon(llvm::CallInst &MS) { return; } void visitMemTransferInst(llvm::MemTransferInst &MTI) { using namespace llvm; Value *isVolatile = getNewFromOriginal(MTI.getOperand(3)); @@ -5024,8 +5050,7 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign auto dstAlign = MTI.getDestAlign(); visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI, MTI.getOperand(0), MTI.getOperand(1), - getNewFromOriginal(MTI.getOperand(2)), - isVolatile); + getNewFromOriginal(MTI.getOperand(2)), isVolatile); } void visitMemTransferCommon(llvm::Intrinsic::ID ID, llvm::MaybeAlign srcAlign, llvm::MaybeAlign dstAlign, llvm::CallInst &MTI, @@ -5033,9 +5058,7 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign llvm::Value *new_size, llvm::Value *isVolatile) { return; } - void visitFenceInst(llvm::FenceInst &FI) { - return; - } + void visitFenceInst(llvm::FenceInst &FI) { return; } void visitIntrinsicInst(llvm::IntrinsicInst &II) { SmallVector orig_ops(II.getNumOperands()); for (unsigned i = 0; i < II.getNumOperands(); ++i) { @@ -5047,19 +5070,11 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign return; } - void visitReturnInst(llvm::ReturnInst &I) { - return; - } + void visitReturnInst(llvm::ReturnInst &I) { return; } - void visitBranchInst(llvm::BranchInst &I) { - return; - } - void visitSwitchInst(llvm::SwitchInst &I) { - return; - } - void visitUnreachableInst(llvm::UnreachableInst &I) { - return; - } + void visitBranchInst(llvm::BranchInst &I) { return; } + void 
visitSwitchInst(llvm::SwitchInst &I) { return; } + void visitUnreachableInst(llvm::UnreachableInst &I) { return; } void visitLoadLike(llvm::Instruction &I, llvm::MaybeAlign alignment, llvm::Value *mask = nullptr, llvm::Value *orig_maskInit = nullptr) { @@ -5070,15 +5085,14 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign llvm::Value *orig_val, llvm::MaybeAlign prevalign, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask) { - return; - } + return; + } bool handleAdjointForIntrinsic(llvm::Intrinsic::ID ID, llvm::Instruction &I, llvm::SmallVectorImpl &orig_ops) { using namespace llvm; - switch (ID) { case Intrinsic::nvvm_ldu_global_i: case Intrinsic::nvvm_ldu_global_p: @@ -5116,60 +5130,60 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign auto called = cast(&I)->getCalledFunction(); (void)called; switch (ID) { -//#include "IntrinsicDerivatives.inc" + //#include "IntrinsicDerivatives.inc" default: break; } - switch (ID) { - case Intrinsic::nvvm_barrier0: - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - case Intrinsic::nvvm_membar_cta: - case Intrinsic::nvvm_membar_gl: - case Intrinsic::nvvm_membar_sys: - case Intrinsic::amdgcn_s_barrier: - return false; - default: break; - } - return true; + switch (ID) { + case Intrinsic::nvvm_barrier0: + case Intrinsic::nvvm_barrier0_popc: + case Intrinsic::nvvm_barrier0_and: + case Intrinsic::nvvm_barrier0_or: + case Intrinsic::nvvm_membar_cta: + case Intrinsic::nvvm_membar_gl: + case Intrinsic::nvvm_membar_sys: + case Intrinsic::amdgcn_s_barrier: + return false; + default: + break; + } + return true; } - llvm::Value *getNewFromOriginal(llvm::Value* v) { + llvm::Value *getNewFromOriginal(llvm::Value *v) { auto found = originalToNewFn.find(v); assert(found != originalToNewFn.end()); return found->second; } - llvm::Instruction 
*getNewFromOriginal(llvm::Instruction* v) { - return cast(getNewFromOriginal((llvm::Value*)v)); + llvm::Instruction *getNewFromOriginal(llvm::Instruction *v) { + return cast(getNewFromOriginal((llvm::Value *)v)); } bool handleKnownCalls(llvm::CallInst &call, llvm::Function *called, - llvm::StringRef funcName, - llvm::CallInst *const newCall) { - return false; - } + llvm::StringRef funcName, + llvm::CallInst *const newCall) { + return false; + } - Value* GetShadow(RequestContext &ctx, Value* v) { + Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) return Logic.CreateTruncate(ctx, F, fromwidth, towidth); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; } - // Return + // Return void visitCallInst(llvm::CallInst &call) { using namespace llvm; CallInst *const newCall = cast(getNewFromOriginal(&call)); IRBuilder<> BuilderZ(newCall); - if (auto called = call.getCalledFunction()) - if (handleKnownCalls(call, called, getFuncNameFromCall(&call), - newCall)) - return; + if (auto called = call.getCalledFunction()) + if (handleKnownCalls(call, called, getFuncNameFromCall(&call), newCall)) + return; RequestContext ctx(&call, &BuilderZ); auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); @@ -5178,10 +5192,13 @@ TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, unsign } }; -llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, llvm::Function *totrunc, - unsigned fromwidth, unsigned towidth){ - if (fromwidth == towidth) return totrunc; - +llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, + llvm::Function *totrunc, + unsigned fromwidth, + unsigned towidth) { + if (fromwidth == towidth) + return totrunc; + TruncateCacheKey tup(totrunc, fromwidth, towidth); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; @@ -5199,7 +5216,9 @@ 
llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, llvm::Functi FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), - "trunc_" + std::to_string(fromwidth) + "_" + std::to_string(towidth) + totrunc->getName(), totrunc->getParent()); + "trunc_" + std::to_string(fromwidth) + "_" + + std::to_string(towidth) + totrunc->getName(), + totrunc->getParent()); NewF->setLinkage(Function::LinkageTypes::InternalLinkage); @@ -5259,10 +5278,10 @@ llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, llvm::Functi llvm_unreachable("attempting to truncate function without definition"); } - ValueToValueMapTy originalToNewFn; - for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); i != totrunc->arg_end();) { + for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); + i != totrunc->arg_end();) { originalToNewFn[i] = j; j->setName(i->getName()); ++j; @@ -5280,7 +5299,8 @@ llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context, llvm::Functi NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, *this); + TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, + *this); for (auto &BB : *totrunc) for (auto &I : BB) handle.visit(&I); From af7cea5e44a58fe94bdb4296eefd4bd7dc48e73d Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 31 Dec 2023 16:39:33 +0900 Subject: [PATCH 03/13] Fix cmp --- enzyme/Enzyme/EnzymeLogic.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index bb7d0981f30a..7f2495baf85c 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4912,7 +4912,7 @@ class TruncateGenerator : public llvm::InstVisitor { truncate(B, getNewFromOriginal(CI.getOperand(1))))); nres->takeName(newI); nres->copyIRFlags(newI); - 
newI->replaceAllUsesWith(expand(B, nres, CI.getType())); + newI->replaceAllUsesWith(nres); newI->eraseFromParent(); return; } From 8ce75d1bf0262ee02ae029c430988cfc7b3f0a60 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 31 Dec 2023 16:39:50 +0900 Subject: [PATCH 04/13] Add cmp test --- enzyme/test/Enzyme/Truncate/cmp.ll | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 enzyme/test/Enzyme/Truncate/cmp.ll diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll new file mode 100644 index 000000000000..7b7581131ac5 --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -0,0 +1,34 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define i1 @f(double %x, double %y) { + %res = fcmp olt double %x, %y + ret i1 %res +} + +declare i1 (double, double)* @__enzyme_truncate(...) + +define i1 @tester(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) 
@__enzyme_truncate(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} + +; CHECK: define i1 @tester(double %x, double %y) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = call i1 @trunc_64_32f(double %x, double %y) +; CHECK-NEXT: ret i1 %res +; CHECK-NEXT: } + +; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to float* +; CHECK-NEXT: %3 = load float, float* %2, align 4 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %4 = bitcast double* %1 to float* +; CHECK-NEXT: %5 = load float, float* %4, align 4 +; CHECK-NEXT: %res = fcmp olt float %5, %3 +; CHECK-NEXT: ret i1 %res +; CHECK-NEXT: } From 00ce1da65d766a74108c0b1929edaee6530ab161 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 31 Dec 2023 16:52:53 +0900 Subject: [PATCH 05/13] Add select --- enzyme/Enzyme/EnzymeLogic.cpp | 11 +++++++- enzyme/test/Enzyme/Truncate/select.ll | 39 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/Enzyme/Truncate/select.ll diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 7f2495baf85c..1077c8e971fb 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4951,7 +4951,16 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitSelectInst(llvm::SelectInst &SI) { - todo(SI); + auto newI = getNewFromOriginal(&SI); + IRBuilder<> B(newI); + auto nres = cast( + B.CreateSelect(getNewFromOriginal(SI.getCondition()), + truncate(B, getNewFromOriginal(SI.getTrueValue())), + truncate(B, getNewFromOriginal(SI.getFalseValue())))); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres, SI.getType())); + newI->eraseFromParent(); return; } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { 
return; } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll new file mode 100644 index 000000000000..db3b9d511afc --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -0,0 +1,39 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +define double @f(double %x, double %y, i1 %cond) { + %res = select i1 %cond, double %x, double %y + ret double %res +} + +declare double (double, double, i1)* @__enzyme_truncate(...) + +define double @tester(double %x, double %y, i1 %cond) { +entry: + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate(double (double, double, i1)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y, i1 %cond) + ret double %res +} + +; CHECK: define double @tester(double %x, double %y, i1 %cond) { +; CHECK-NEXT: entry: +; CHECK-NEXT: %res = call double @trunc_64_32f(double %x, double %y, i1 %cond) +; CHECK-NEXT: ret double %res +; CHECK-NEXT: } + +; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to float* +; CHECK-NEXT: %3 = load float, float* %2, align 4 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %4 = bitcast double* %1 to float* +; CHECK-NEXT: %5 = load float, float* %4, align 4 +; CHECK-NEXT: %res = select i1 %cond, float %5, float %3 +; CHECK-NEXT: %6 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %6, align 4 +; CHECK-NEXT: %7 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res, float* %7, align 4 +; CHECK-NEXT: %8 = load double, double* %1, align 8 +; CHECK-NEXT: ret double %8 +; CHECK-NEXT: } From d549b0f95ee105c1ea972476e23f03b6ca1ec399 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 31 Dec 2023 16:56:10 +0900 Subject: [PATCH 06/13] 
clang-format --- enzyme/Enzyme/Enzyme.cpp | 12 +++++++----- enzyme/Enzyme/EnzymeLogic.cpp | 4 ++-- enzyme/Enzyme/EnzymeLogic.h | 5 +++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index d4435b483682..f1c762b45774 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1321,17 +1321,19 @@ class EnzymeBase { if (!F) return false; if (CI->arg_size() != 3) { - EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, - "Had incorrect number of args to __enzyme_truncate", *CI, - " - expected 3"); - return false; + EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, + "Had incorrect number of args to __enzyme_truncate", *CI, + " - expected 3"); + return false; } auto Cfrom = cast(CI->getArgOperand(1)); assert(Cfrom); auto Cto = cast(CI->getArgOperand(2)); assert(Cto); RequestContext context(CI, &Builder); - llvm::Value* res = Logic.CreateTruncate(context, F, (unsigned)Cfrom->getValue().getZExtValue(), (unsigned)Cto->getValue().getZExtValue() ); + llvm::Value *res = Logic.CreateTruncate( + context, F, (unsigned)Cfrom->getValue().getZExtValue(), + (unsigned)Cto->getValue().getZExtValue()); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 1077c8e971fb..fd881ee69efa 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4842,7 +4842,7 @@ class TruncateGenerator : public llvm::InstVisitor { // below switch (inst.getOpcode()) { - //#include "InstructionDerivatives.inc" + // #include "InstructionDerivatives.inc" default: break; } @@ -5139,7 +5139,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto called = cast(&I)->getCalledFunction(); (void)called; switch (ID) { - //#include "IntrinsicDerivatives.inc" + // #include "IntrinsicDerivatives.inc" default: break; } diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 
6a585a60f12e..4ce25e8ae465 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -512,8 +512,9 @@ class EnzymeLogic { using TruncateCacheKey = std::tuple; std::map TruncateCachedFunctions; - llvm::Function *CreateTruncate(RequestContext context, llvm::Function *tobatch, - unsigned fromwidth, unsigned towidth); + llvm::Function *CreateTruncate(RequestContext context, + llvm::Function *tobatch, unsigned fromwidth, + unsigned towidth); /// Create a traced version of a function /// \p context the instruction which requested this trace (or null). From 963fe47e393c980ece52ba408cf585bea0807561 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Sun, 31 Dec 2023 17:06:59 +0900 Subject: [PATCH 07/13] Weird newlines --- enzyme/Enzyme/Enzyme.cpp | 1 - enzyme/test/Enzyme/Truncate/simple.ll | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index f1c762b45774..a514a826351b 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1,4 +1,3 @@ - //===- Enzyme.cpp - Automatic Differentiation Transformation Pass -------===// // // Enzyme Project diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index b0cea56f2292..79f6c269b7b4 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -40,4 +40,4 @@ entry: ; CHECK-NEXT: %8 = load double, double* %1, align 8 ; CHECK-NEXT: store double %8, double* %x, align 8 ; CHECK-NEXT: ret void -; CHECK-NEXT: } \ No newline at end of file +; CHECK-NEXT: } From 15de122c66b83aad30ee3069d2199b560b572610 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 1 Jan 2024 03:00:06 +0900 Subject: [PATCH 08/13] Make output reproducible --- enzyme/Enzyme/EnzymeLogic.cpp | 35 +++++++++++---------------- enzyme/test/Enzyme/Truncate/cmp.ll | 20 +++++++-------- enzyme/test/Enzyme/Truncate/select.ll | 28 ++++++++++----------- enzyme/test/Enzyme/Truncate/simple.ll | 
33 ++++++++++++------------- 4 files changed, 54 insertions(+), 62 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index fd881ee69efa..bc9402451734 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4907,9 +4907,10 @@ class TruncateGenerator : public llvm::InstVisitor { void visitFCmpInst(llvm::FCmpInst &CI) { auto newI = getNewFromOriginal(&CI); IRBuilder<> B(newI); - auto nres = cast(B.CreateFCmp( - CI.getPredicate(), truncate(B, getNewFromOriginal(CI.getOperand(0))), - truncate(B, getNewFromOriginal(CI.getOperand(1))))); + auto truncLHS = truncate(B, getNewFromOriginal(CI.getOperand(0))); + auto truncRHS = truncate(B, getNewFromOriginal(CI.getOperand(1))); + auto nres = + cast(B.CreateFCmp(CI.getPredicate(), truncLHS, truncRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(nres); @@ -4953,10 +4954,10 @@ class TruncateGenerator : public llvm::InstVisitor { void visitSelectInst(llvm::SelectInst &SI) { auto newI = getNewFromOriginal(&SI); IRBuilder<> B(newI); + auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); + auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); auto nres = cast( - B.CreateSelect(getNewFromOriginal(SI.getCondition()), - truncate(B, getNewFromOriginal(SI.getTrueValue())), - truncate(B, getNewFromOriginal(SI.getFalseValue())))); + B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, SI.getType())); @@ -4992,13 +4993,13 @@ class TruncateGenerator : public llvm::InstVisitor { if (towidth == 32 || towidth == 16 || towidth == 64) { auto newI = getNewFromOriginal(&BO); IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); + auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); switch (BO.getOpcode()) { default: break; case BinaryOperator::FMul: { - auto nres = cast( - 
B.CreateFMul(truncate(B, getNewFromOriginal(BO.getOperand(0))), - truncate(B, getNewFromOriginal(BO.getOperand(1))))); + auto nres = cast(B.CreateFMul(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); @@ -5006,9 +5007,7 @@ class TruncateGenerator : public llvm::InstVisitor { } return; case BinaryOperator::FAdd: { - auto nres = cast( - B.CreateFAdd(truncate(B, getNewFromOriginal(BO.getOperand(0))), - truncate(B, getNewFromOriginal(BO.getOperand(1))))); + auto nres = cast(B.CreateFAdd(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); @@ -5016,9 +5015,7 @@ class TruncateGenerator : public llvm::InstVisitor { } return; case BinaryOperator::FSub: { - auto nres = cast( - B.CreateFSub(truncate(B, getNewFromOriginal(BO.getOperand(0))), - truncate(B, getNewFromOriginal(BO.getOperand(1))))); + auto nres = cast(B.CreateFSub(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); @@ -5026,9 +5023,7 @@ class TruncateGenerator : public llvm::InstVisitor { } return; case BinaryOperator::FDiv: { - auto nres = cast( - B.CreateFDiv(truncate(B, getNewFromOriginal(BO.getOperand(0))), - truncate(B, getNewFromOriginal(BO.getOperand(1))))); + auto nres = cast(B.CreateFDiv(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); @@ -5036,9 +5031,7 @@ class TruncateGenerator : public llvm::InstVisitor { } return; case BinaryOperator::FRem: { - auto nres = cast( - B.CreateFRem(truncate(B, getNewFromOriginal(BO.getOperand(0))), - truncate(B, getNewFromOriginal(BO.getOperand(1))))); + auto nres = cast(B.CreateFRem(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); newI->replaceAllUsesWith(expand(B, nres, BO.getType())); diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index 
7b7581131ac5..3c2cffec9979 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -22,13 +22,13 @@ entry: ; CHECK-NEXT: } ; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res = fcmp olt float %5, %3 -; CHECK-NEXT: ret i1 %res -; CHECK-NEXT: } +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res = fcmp olt float %3, %5 +; CHECK-DAG: ret i1 %res +; CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index db3b9d511afc..ae539469b9a2 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -22,18 +22,18 @@ entry: ; CHECK-NEXT: } ; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res = select i1 %cond, float %5, float %3 -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float 
%res, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: ret double %8 +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res = select i1 %cond, float %3, float %5 +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: ret double %8 ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 79f6c269b7b4..69990236a29e 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -24,20 +24,19 @@ entry: ; CHECK-NEXT: } ; CHECK: define internal void @trunc_64_32f(double* %x) -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: %y = load double, double* %x, align 8 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %m = fmul float %5, %3 -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float %m, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: store double %8, double* %x, align 8 -; CHECK-NEXT: ret void -; CHECK-NEXT: } +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: store 
double %y, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %m = fmul float %3, %5 +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %m, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %8, double* %x, align 8 +; CHECK-DAG: ret void From 8fa918841eb72df6cd981a380e4f5b0f5433a2bd Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 12 Jan 2024 05:31:33 +0900 Subject: [PATCH 09/13] Handle math intrinsics --- enzyme/Enzyme/EnzymeLogic.cpp | 84 ++++++++++++++---------- enzyme/test/Enzyme/Truncate/intrinsic.ll | 23 +++++++ 2 files changed, 71 insertions(+), 36 deletions(-) create mode 100644 enzyme/test/Enzyme/Truncate/intrinsic.ll diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index bc9402451734..89a1259f41b3 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -29,6 +29,7 @@ //===----------------------------------------------------------------------===// #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "llvm/IR/Intrinsics.h" #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -4876,7 +4877,8 @@ class TruncateGenerator : public llvm::InstVisitor { B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType))); } - Value *expand(IRBuilder<> &B, Value *v, Type *origT) { + Value *expand(IRBuilder<> &B, Value *v) { + Type *origT = getFromType(); auto c0 = Constant::getNullValue( llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth)); B.CreateStore(c0, B.CreatePointerCast( @@ -4960,7 +4962,7 @@ class TruncateGenerator : public llvm::InstVisitor { 
B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, SI.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); return; } @@ -5002,7 +5004,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto nres = cast(B.CreateFMul(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); } return; @@ -5010,7 +5012,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto nres = cast(B.CreateFAdd(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); } return; @@ -5018,7 +5020,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto nres = cast(B.CreateFSub(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); } return; @@ -5026,7 +5028,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto nres = cast(B.CreateFDiv(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); } return; @@ -5034,7 +5036,7 @@ class TruncateGenerator : public llvm::InstVisitor { auto nres = cast(B.CreateFRem(newLHS, newRHS)); nres->takeName(newI); nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres, BO.getType())); + newI->replaceAllUsesWith(expand(B, nres)); newI->eraseFromParent(); } return; @@ -5062,13 +5064,44 @@ class TruncateGenerator : public llvm::InstVisitor { } void visitFenceInst(llvm::FenceInst &FI) { return; } void visitIntrinsicInst(llvm::IntrinsicInst &II) { - SmallVector 
orig_ops(II.getNumOperands()); - for (unsigned i = 0; i < II.getNumOperands(); ++i) { + SmallVector orig_ops(II.arg_size()); + for (unsigned i = 0; i < II.arg_size(); ++i) orig_ops[i] = II.getOperand(i); - } if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) return; - todo(II); + + bool hasFromType = false; + auto newI = cast(getNewFromOriginal(&II)); + IRBuilder<> B(newI); + SmallVector new_ops(II.arg_size()); + for (unsigned i = 0; i < II.arg_size(); ++i) { + if (orig_ops[i]->getType() == getFromType()) { + new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); + hasFromType = true; + } else { + new_ops[i] = getNewFromOriginal(orig_ops[i]); + } + } + Type *retTy = II.getType(); + if (II.getType() == getFromType()) { + hasFromType = true; + retTy = getToType(); + } + + if (!hasFromType) + return; + + // TODO check that the intrinsic is overloaded + + CallInst *intr; + Value *nres = intr = B.CreateIntrinsic(retTy, II.getIntrinsicID(), new_ops, + &II, II.getName()); + if (II.getType() == getFromType()) + nres = expand(B, nres); + intr->copyIRFlags(newI); + newI->replaceAllUsesWith(nres); + newI->eraseFromParent(); + return; } @@ -5104,7 +5137,7 @@ class TruncateGenerator : public llvm::InstVisitor { case Intrinsic::nvvm_ldg_global_f: { auto CI = cast(I.getOperand(1)); visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue())); - return false; + return true; } default: break; @@ -5118,7 +5151,7 @@ class TruncateGenerator : public llvm::InstVisitor { /*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic, SyncScope::SingleThread, /*mask*/ getNewFromOriginal(I.getOperand(3))); - return false; + return true; } if (ID == Intrinsic::masked_load) { auto align0 = cast(I.getOperand(1))->getZExtValue(); @@ -5126,31 +5159,10 @@ class TruncateGenerator : public llvm::InstVisitor { visitLoadLike(I, align, /*mask*/ getNewFromOriginal(I.getOperand(2)), /*orig_maskInit*/ I.getOperand(3)); - return false; - } - - auto called = cast(&I)->getCalledFunction(); - 
(void)called; - switch (ID) { - // #include "IntrinsicDerivatives.inc" - default: - break; + return true; } - switch (ID) { - case Intrinsic::nvvm_barrier0: - case Intrinsic::nvvm_barrier0_popc: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - case Intrinsic::nvvm_membar_cta: - case Intrinsic::nvvm_membar_gl: - case Intrinsic::nvvm_membar_sys: - case Intrinsic::amdgcn_s_barrier: - return false; - default: - break; - } - return true; + return false; } llvm::Value *getNewFromOriginal(llvm::Value *v) { diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll new file mode 100644 index 000000000000..dbc81f25adb8 --- /dev/null +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s + +declare double @llvm.pow.f64(double %Val, double %Power) +declare double @llvm.powi.f64.i16(double %Val, i16 %power) +declare void @llvm.nvvm.barrier0() + +define double @f(double %x, double %y) { + %res1 = call double @llvm.pow.f64(double %x, double %y) + %res2 = call double @llvm.powi.f64.i16(double %x, i16 2) + %res = fadd double %res1, %res2 + call void @llvm.nvvm.barrier0() + ret double %res +} + +declare double (double, double)* @__enzyme_truncate(...) + +define double @tester(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) 
@__enzyme_truncate(double (double, double)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y) + ret double %res +} From 67c54f77ba7a0af3eba8724495eeb493dcd0f2d1 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 12 Jan 2024 05:41:11 +0900 Subject: [PATCH 10/13] Add checks --- enzyme/test/Enzyme/Truncate/intrinsic.ll | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index dbc81f25adb8..1c3e9573db04 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -21,3 +21,42 @@ entry: %res = call double %ptr(double %x, double %y) ret double %res } + +; CHECK: define internal double @trunc_64_32f(double %x, double %y) { +; CHECK-NEXT: %1 = alloca double, align 8 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %2 = bitcast double* %1 to float* +; CHECK-NEXT: %3 = load float, float* %2, align 4 +; CHECK-NEXT: store double %y, double* %1, align 8 +; CHECK-NEXT: %4 = bitcast double* %1 to float* +; CHECK-NEXT: %5 = load float, float* %4, align 4 +; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-NEXT: %6 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %6, align 4 +; CHECK-NEXT: %7 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res11, float* %7, align 4 +; CHECK-NEXT: %8 = load double, double* %1, align 8 +; CHECK-NEXT: store double %x, double* %1, align 8 +; CHECK-NEXT: %9 = bitcast double* %1 to float* +; CHECK-NEXT: %10 = load float, float* %9, align 4 +; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) +; CHECK-NEXT: %11 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %11, align 4 +; CHECK-NEXT: %12 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res22, float* %12, align 4 +; CHECK-NEXT: %13 = load double, double* %1, align 8 +; CHECK-NEXT: store double %8, double* %1, align 8 
+; CHECK-NEXT: %14 = bitcast double* %1 to float* +; CHECK-NEXT: %15 = load float, float* %14, align 4 +; CHECK-NEXT: store double %13, double* %1, align 8 +; CHECK-NEXT: %16 = bitcast double* %1 to float* +; CHECK-NEXT: %17 = load float, float* %16, align 4 +; CHECK-NEXT: %res = fadd float %15, %17 +; CHECK-NEXT: %18 = bitcast double* %1 to i64* +; CHECK-NEXT: store i64 0, i64* %18, align 4 +; CHECK-NEXT: %19 = bitcast double* %1 to float* +; CHECK-NEXT: store float %res, float* %19, align 4 +; CHECK-NEXT: %20 = load double, double* %1, align 8 +; CHECK-NEXT: call void @llvm.nvvm.barrier0() +; CHECK-NEXT: ret double %20 +; CHECK-NEXT: } From 60750aa0e7ae3db5c3dcbda025c5a335add99306 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 12 Jan 2024 06:46:35 +0900 Subject: [PATCH 11/13] Fix older LLVM versions --- enzyme/Enzyme/EnzymeLogic.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 89a1259f41b3..31e4b45ca7b9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5094,8 +5094,29 @@ class TruncateGenerator : public llvm::InstVisitor { // TODO check that the intrinsic is overloaded CallInst *intr; +#if LLVM_VERSION_MAJOR >= 16 Value *nres = intr = B.CreateIntrinsic(retTy, II.getIntrinsicID(), new_ops, &II, II.getName()); +#else + // Older version do not automatically mangle the intrinsic for us - we need + // to provide the types to mangle with + SmallVector Table; + getIntrinsicInfoTableEntries(II.getIntrinsicID(), Table); + ArrayRef TableRef = Table; + SmallVector ArgTys; + Intrinsic::MatchIntrinsicTypesResult Res = + Intrinsic::matchIntrinsicSignature(II.getFunctionType(), TableRef, + ArgTys); + assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchRet && + "Intrinsic has incorrect return type!"); + assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchArg && + "Intrinsic has incorrect argument type!"); + for (unsigned i = 0; i < 
ArgTys.size(); i++) + if (ArgTys[i] == getFromType()) + ArgTys[i] = getToType(); + Value *nres = intr = B.CreateIntrinsic(II.getIntrinsicID(), ArgTys, new_ops, + &II, II.getName()); +#endif if (II.getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); From 9b0e9a9b72195d04193b512219701c3ad25fe08a Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 12 Jan 2024 09:30:26 +0900 Subject: [PATCH 12/13] Limit test to llvm > 12 --- enzyme/test/Enzyme/Truncate/intrinsic.ll | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 1c3e9573db04..ea92f5d96bbc 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s +; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi +; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi declare double @llvm.pow.f64(double %Val, double %Power) declare double @llvm.powi.f64.i16(double %Val, i16 %power) From 89ed081a601a2e1c10bb9b96d4991f3748e40911 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Fri, 12 Jan 2024 10:19:05 +0900 Subject: [PATCH 13/13] Util function for intrinsic creation --- enzyme/Enzyme/EnzymeLogic.cpp | 25 ++----------------------- enzyme/Enzyme/Utils.cpp | 32 ++++++++++++++++++++++++++++++++ enzyme/Enzyme/Utils.h | 6 ++++++ 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 31e4b45ca7b9..5000ed34bcaa 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5094,29 +5094,8 @@ class TruncateGenerator : public llvm::InstVisitor { // TODO check that the 
intrinsic is overloaded CallInst *intr; -#if LLVM_VERSION_MAJOR >= 16 - Value *nres = intr = B.CreateIntrinsic(retTy, II.getIntrinsicID(), new_ops, - &II, II.getName()); -#else - // Older version do not automatically mangle the intrinsic for us - we need - // to provide the types to mangle with - SmallVector Table; - getIntrinsicInfoTableEntries(II.getIntrinsicID(), Table); - ArrayRef TableRef = Table; - SmallVector ArgTys; - Intrinsic::MatchIntrinsicTypesResult Res = - Intrinsic::matchIntrinsicSignature(II.getFunctionType(), TableRef, - ArgTys); - assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchRet && - "Intrinsic has incorrect return type!"); - assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchArg && - "Intrinsic has incorrect argument type!"); - for (unsigned i = 0; i < ArgTys.size(); i++) - if (ArgTys[i] == getFromType()) - ArgTys[i] = getToType(); - Value *nres = intr = B.CreateIntrinsic(II.getIntrinsicID(), ArgTys, new_ops, - &II, II.getName()); -#endif + Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy, + new_ops, &II, II.getName()); if (II.getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 455346addf21..a177dfc7fad4 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2825,3 +2825,35 @@ bool collectOffset(GEPOperator *gep, const DataLayout &DL, unsigned BitWidth, return true; #endif } + +llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, + llvm::Intrinsic::ID ID, llvm::Type *RetTy, + llvm::ArrayRef Args, + llvm::Instruction *FMFSource, + const llvm::Twine &Name) { +#if LLVM_VERSION_MAJOR >= 16 + llvm::CallInst *nres = B.CreateIntrinsic(RetTy, ID, Args, FMFSource, Name); +#else + SmallVector Table; + Intrinsic::getIntrinsicInfoTableEntries(ID, Table); + ArrayRef TableRef(Table); + + SmallVector ArgTys; + ArgTys.reserve(Args.size()); + for (auto &I : Args) + ArgTys.push_back(I->getType()); + FunctionType *FTy = 
FunctionType::get(RetTy, ArgTys, false); + SmallVector OverloadTys; + Intrinsic::MatchIntrinsicTypesResult Res = + matchIntrinsicSignature(FTy, TableRef, OverloadTys); + (void)Res; + assert(Res == Intrinsic::MatchIntrinsicTypes_Match && TableRef.empty() && + "Wrong types for intrinsic!"); + Function *Fn = Intrinsic::getDeclaration(B.GetInsertPoint()->getModule(), ID, + OverloadTys); + CallInst *nres = B.CreateCall(Fn, Args, {}, Name); + if (FMFSource) + nres->copyFastMathFlags(FMFSource); +#endif + return nres; +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index b7b7b5f31f29..32401d2f17e5 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1807,4 +1807,10 @@ bool collectOffset(llvm::GEPOperator *gep, const llvm::DataLayout &DL, unsigned BitWidth, llvm::MapVector &VariableOffsets, llvm::APInt &ConstantOffset); + +llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &B, + llvm::Intrinsic::ID ID, llvm::Type *RetTy, + llvm::ArrayRef Args, + llvm::Instruction *FMFSource = nullptr, + const llvm::Twine &Name = ""); #endif // ENZYME_UTILS_H