From 92ba9319b150ba5a77bf015b3c507972b3b6ab73 Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Mon, 9 Dec 2024 12:41:17 -0600 Subject: [PATCH] experimental expr JIT evaluation --- enzyme/Enzyme/Herbie.cpp | 135 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 130 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 7a681b03fa7..b669b7924a1 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -13,6 +13,9 @@ #include "llvm/Demangle/Demangle.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" + #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" @@ -20,10 +23,12 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" #include "llvm/Support/InstructionCost.h" +#include "llvm/Support/JSON.h" #include "llvm/Support/Program.h" +#include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" -#include #include "llvm/Pass.h" @@ -91,7 +96,10 @@ static cl::opt HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden, "candidate expressions.")); static cl::opt FPOptCachePath("fpopt-cache-path", cl::init(""), cl::Hidden, - cl::desc("Experimental: path to cache Herbie results")); + cl::desc("Path to cache Herbie results")); +static cl::opt + FPOptEnableJIT("fpopt-enable-jit", cl::init(true), cl::Hidden, + cl::desc("Experimental: Use JIT in candidate evaluation")); static cl::opt HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden, cl::desc("Number of input points Herbie uses to evaluate " @@ -2929,11 +2937,116 @@ class ApplicableFPCC { } }; +void JITExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF, ArrayRef outputs, + const SmallMapVector &inputValues, + SmallVectorImpl &results) { + using namespace llvm::orc; + // llvm::errs() << "JIT'ting " << expr << "\n"; + + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + size_t NumInputs = argStrSet.size(); + size_t NumOutputs = 1; + + auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap); + + auto TSCtx = std::make_unique(); + LLVMContext &Ctx = *TSCtx; + std::unique_ptr UniqueM = std::make_unique("jit_module", Ctx); + + Type *Int64Ty = Type::getInt64Ty(Ctx); + Type *DoubleTy = Type::getDoubleTy(Ctx); + Type *DoublePtrTy = Type::getDoublePtrTy(Ctx); + + FunctionType *FT = + FunctionType::get(Type::getVoidTy(Ctx), + {Int64Ty, Int64Ty, DoublePtrTy, DoublePtrTy}, false); + Function *JitFunc = Function::Create(FT, Function::ExternalLinkage, + "tempExpr", UniqueM.get()); + + auto ArgIt = JitFunc->arg_begin(); + Value *NInVal = &*ArgIt++; + NInVal->setName("numInputs"); + Value *NOutVal = &*ArgIt++; + NOutVal->setName("numOutputs"); + Value *InArr = &*ArgIt++; + InArr->setName("inputs"); + Value *OutArr = &*ArgIt++; + OutArr->setName("outputs"); + + BasicBlock *entry = BasicBlock::Create(Ctx, "entry", JitFunc); + Instruction *ReturnInst = ReturnInst::Create(Ctx, entry); + IRBuilder<> builder(ReturnInst); + builder.setFastMathFlags(FMF); + + std::vector argNames(argStrSet.begin(), argStrSet.end()); + std::unordered_map argIndexMap; + for (unsigned i = 0; i < argNames.size(); i++) + argIndexMap[argNames[i]] = i; + + // Load input values from the input array + ValueToValueMapTy VMap; + for (auto &kv : symbolToValueMap) { + const std::string &sym = kv.first; + Value *origVal = kv.second; + if (argIndexMap.count(sym)) { + Value *Index = builder.getInt64(argIndexMap[sym]); + Value *Ptr = builder.CreateGEP(DoubleTy, InArr, Index, sym + "_ptr"); + Value *Loaded = builder.CreateLoad(DoubleTy, Ptr, sym); + VMap[origVal] = Loaded; + } + } + + // Materialize the expression + Value *Expr = parsedNode->getLLValue(builder, &VMap); + + // Store the result in the output array + Value *OutPtr = + builder.CreateGEP(DoubleTy, OutArr, {builder.getInt64(0)}, "out_ptr"); + builder.CreateStore(Expr, OutPtr); + + // JitFunc->print(llvm::errs()); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + auto J = cantFail(LLJITBuilder().create()); + ThreadSafeModule TSM(std::move(UniqueM), std::move(TSCtx)); + cantFail(J->addIRModule(std::move(TSM))); + + auto Sym = cantFail(J->lookup("tempExpr")); + + using JitFuncTy = void (*)(int64_t, int64_t, const double *, double *); + auto *JitFuncPtr = Sym.toPtr(); + + std::vector inputVals(NumInputs, 0.0); + for (unsigned i = 0; i < argNames.size(); i++) { + Value *argVal = symbolToValueMap[argNames[i]]; + auto it = inputValues.find(argVal); + assert(it != inputValues.end() && + "Missing input value for a required argument!"); + inputVals[i] = it->second; + } + + std::vector outputVals(NumOutputs, 0.0); + + JitFuncPtr((int64_t)NumInputs, (int64_t)NumOutputs, inputVals.data(), + outputVals.data()); + + results.clear(); + results.append(outputVals.begin(), outputVals.end()); +} + void setUnifiedAccuracyCost( ApplicableOutput &AO, std::unordered_map> &valueToNodeMap, std::unordered_map &symbolToValueMap) { - SmallVector, 4> sampledPoints; getSampledPoints(AO.component->inputs.getArrayRef(), valueToNodeMap, symbolToValueMap, sampledPoints); @@ -2951,7 +3064,13 @@ void setUnifiedAccuracyCost( // llvm::errs() << "DEBUG AO gold value: " << goldVal << "\n"; goldVals[pair.index()] = goldVal; - getFPValues(outputs, pair.value(), results); + if (FPOptEnableJIT) { + JITExpr(AO.expr, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags(), outputs, + pair.value(), results); + } else { + getFPValues(outputs, pair.value(), results); + } double realVal = results[0]; // llvm::errs() << "DEBUG AO real value: " << realVal << "\n"; @@ -2993,7 +3112,13 @@ void setUnifiedAccuracyCost( ArrayRef outputs = {parsedNode.get()}; SmallVector results; - getFPValues(outputs, pair.value(), results); + if (FPOptEnableJIT) { + JITExpr(expr, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags(), outputs, + pair.value(), results); + } else { + getFPValues(outputs, pair.value(), results); + } double realVal = results[0]; // llvm::errs() << "Real value: " << realVal << "\n";