From c67f13dad03827f69179da95b19ff29d8667cbb4 Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Fri, 8 Nov 2024 03:07:36 -0600 Subject: [PATCH] parallel herbie --- enzyme/Enzyme/Herbie.cpp | 137 ++++++++++++++++++++++----------------- 1 file changed, 76 insertions(+), 61 deletions(-) diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 4f3d53d6178..da445f87db7 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -83,6 +83,9 @@ static cl::opt FPOptEnableHerbie( static cl::opt FPOptEnablePT( "fpopt-enable-pt", cl::init(false), cl::Hidden, cl::desc("Consider precision changes of floating-point expressions")); +static cl::opt HerbieNumThreads("herbie-num-threads", cl::init(16), + cl::Hidden, + cl::desc("Number of threads Herbie uses")); static cl::opt HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden, cl::desc("Herbie's timeout to use for each " "candidate expressions.")); @@ -2603,7 +2606,7 @@ class ApplicableOutput { std::string expr; double grad; unsigned executions; - const TargetTransformInfo &TTI; + const TargetTransformInfo *TTI; double initialAccCost; // Requires manual initialization InstructionCost initialCompCost; // Requires manual initialization double initialHerbieCost; // Requires manual initialization @@ -2615,7 +2618,7 @@ class ApplicableOutput { double grad, unsigned executions, const TargetTransformInfo &TTI) : component(&component), oldOutput(oldOutput), expr(expr), grad(grad), - executions(executions), TTI(TTI) { + executions(executions), TTI(&TTI) { initialCompCost = getCompCost({oldOutput}, component.inputs, TTI); findErasableInstructions(); } @@ -2662,7 +2665,7 @@ class ApplicableOutput { InstructionCost erasableCost = 0; for (auto *I : erasableInsts) { - erasableCost += getInstructionCompCost(I, TTI); + erasableCost += getInstructionCompCost(I, *TTI); } return (candidates[candidateIndex].CompCost - erasableCost) * executions; @@ -3094,7 +3097,8 @@ void setUnifiedAccuracyCost( } bool improveViaHerbie( - const std::string &inputExpr, ApplicableOutput &AO, Module *M, + const std::vector &inputExprs, + std::vector &AOs, Module *M, const TargetTransformInfo &TTI, std::unordered_map> &valueToNodeMap, std::unordered_map &symbolToValueMap) { @@ -3105,6 +3109,7 @@ bool improveViaHerbie( Program, "report", "--seed", std::to_string(FPOptRandomSeed), "--timeout", std::to_string(HerbieTimeout), + "--threads", std::to_string(HerbieNumThreads), "--num-points", std::to_string(HerbieNumPoints), "--num-iters", std::to_string(HerbieNumIters)}; @@ -3160,9 +3165,8 @@ bool improveViaHerbie( BaseArgsList.push_back(BaseArgs); } - bool InitialValuesSet = false; - std::unordered_set seenExprs; + bool success = false; for (const auto &BaseArgs : BaseArgsList) { SmallString<32> tmpin, tmpout; @@ -3187,7 +3191,9 @@ bool improveViaHerbie( llvm::sys::fs::remove(tmpout); continue; } - input << inputExpr; + for (const auto &expr : inputExprs) { + input << expr << "\n"; + } input.close(); SmallVector Args = BaseArgs; @@ -3238,71 +3244,74 @@ bool improveViaHerbie( json::Object *obj = parsed->getAsObject(); json::Array &tests = *obj->getArray("tests"); - StringRef bestExpr = tests[0].getAsObject()->getString("output").getValue(); - if (bestExpr == "#f") { - continue; - } + assert(tests.size() == AOs.size() && + "improveViaHerbie: Size mismatch between number of tests and AOs"); - if (seenExprs.count(bestExpr.str()) != 0) { - continue; // Expression already seen, skip it - } - seenExprs.insert(bestExpr.str()); + for (size_t i = 0; i < tests.size(); ++i) { + auto &test = *tests[i].getAsObject(); - double bits = tests[0].getAsObject()->getNumber("bits").getValue(); - json::Array &costAccuracy = - *tests[0].getAsObject()->getArray("cost-accuracy"); + StringRef bestExpr = test.getString("output").getValue(); - json::Array &initial = *costAccuracy[0].getAsArray(); - double initialCostVal = initial[0].getAsNumber().getValue(); - double initialCost = 1.0; - double initialAccuracy = 1.0 - initial[1].getAsNumber().getValue() / bits; + if (bestExpr == "#f") { + continue; + } + + double bits = test.getNumber("bits").getValue(); + json::Array &costAccuracy = *test.getArray("cost-accuracy"); + + json::Array &initial = *costAccuracy[0].getAsArray(); + double initialCostVal = initial[0].getAsNumber().getValue(); + double initialCost = 1.0; + double initialAccuracy = 1.0 - initial[1].getAsNumber().getValue() / bits; + + ApplicableOutput &AO = AOs[i]; - if (!InitialValuesSet) { AO.initialHerbieCost = initialCost; AO.initialHerbieAccuracy = initialAccuracy; - InitialValuesSet = true; - } - json::Array &best = *costAccuracy[1].getAsArray(); - double bestCost = best[0].getAsNumber().getValue() / initialCostVal; - double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().getValue() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; - RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); - bestCandidate.CompCost = - getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, - cast(AO.oldOutput)->getFastMathFlags()); - AO.candidates.push_back(bestCandidate); + RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); + bestCandidate.CompCost = + getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(bestCandidate); - json::Array &alternatives = *costAccuracy[2].getAsArray(); + json::Array &alternatives = *costAccuracy[2].getAsArray(); - // Handle alternatives - for (size_t i = 0; i < alternatives.size(); ++i) { - json::Array &entry = *alternatives[i].getAsArray(); - StringRef expr = entry[2].getAsString().getValue(); + std::unordered_set seenExprs; + seenExprs.insert(bestExpr.str()); - if (seenExprs.count(expr.str()) != 0) { - continue; + // Handle alternatives + for (size_t j = 0; j < alternatives.size(); ++j) { + json::Array &entry = *alternatives[j].getAsArray(); + StringRef expr = entry[2].getAsString().getValue(); + + if (seenExprs.count(expr.str()) != 0) { + continue; + } + seenExprs.insert(expr.str()); + + double cost = entry[0].getAsNumber().getValue() / initialCostVal; + double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits; + + RewriteCandidate candidate(cost, accuracy, expr.str()); + candidate.CompCost = + getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(candidate); } - seenExprs.insert(expr.str()); - double cost = entry[0].getAsNumber().getValue() / initialCostVal; - double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits; + setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap); - RewriteCandidate candidate(cost, accuracy, expr.str()); - candidate.CompCost = - getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, - cast(AO.oldOutput)->getFastMathFlags()); - AO.candidates.push_back(candidate); + success = true; } } - if (AO.candidates.empty()) { - return false; - } - - setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap); - return true; + return success; } std::string getHerbieOperator(const Instruction &I) { @@ -4266,6 +4275,9 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { << *input << "\n"; } + std::vector herbieInputs; + std::vector newAOs; + assert(component.outputs.size() > 0 && "No outputs found for component"); for (auto &output : component.outputs) { // 3) run fancy opts @@ -4310,16 +4322,19 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { if (EnzymePrintHerbie) llvm::errs() << "Herbie input:\n" << herbieInput << "\n"; + herbieInputs.push_back(herbieInput); ApplicableOutput AO(component, output, expr, grad, executions, TTI); - if (!improveViaHerbie(herbieInput, AO, F.getParent(), TTI, - valueToNodeMap, symbolToValueMap)) { - if (EnzymePrintHerbie) - llvm::errs() << "Failed to optimize an expression using Herbie!\n"; - continue; - } + newAOs.push_back(std::move(AO)); + } - AOs.push_back(std::move(AO)); + if (!improveViaHerbie(herbieInputs, newAOs, F.getParent(), TTI, + valueToNodeMap, symbolToValueMap)) { + if (EnzymePrintHerbie) + llvm::errs() << "Failed to optimize expressions using Herbie!\n"; } + + AOs.insert(AOs.end(), std::make_move_iterator(newAOs.begin()), + std::make_move_iterator(newAOs.end())); } if (FPOptEnablePT) {