Skip to content

Commit

Permalink
parallel herbie
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Nov 8, 2024
1 parent 6631026 commit c67f13d
Showing 1 changed file with 76 additions and 61 deletions.
137 changes: 76 additions & 61 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ static cl::opt<bool> FPOptEnableHerbie(
// Hidden command-line flag `fpopt-enable-pt` (off by default). When set, the
// pass also considers precision changes of floating-point expressions.
static cl::opt<bool> FPOptEnablePT(
"fpopt-enable-pt", cl::init(false), cl::Hidden,
cl::desc("Consider precision changes of floating-point expressions"));
// Hidden command-line flag `herbie-num-threads` (default 16). Its value is
// stringified and passed to the Herbie invocation as the `--threads` argument.
static cl::opt<int> HerbieNumThreads("herbie-num-threads", cl::init(16),
cl::Hidden,
cl::desc("Number of threads Herbie uses"));
// Hidden command-line flag `herbie-timeout` (default 120). Its value is
// stringified and passed to the Herbie invocation as the `--timeout` argument.
// NOTE(review): the unit is presumably seconds -- confirm against Herbie's CLI.
static cl::opt<int> HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden,
cl::desc("Herbie's timeout to use for each "
"candidate expressions."));
Expand Down Expand Up @@ -2603,7 +2606,7 @@ class ApplicableOutput {
std::string expr;
double grad;
unsigned executions;
const TargetTransformInfo &TTI;
const TargetTransformInfo *TTI;
double initialAccCost; // Requires manual initialization
InstructionCost initialCompCost; // Requires manual initialization
double initialHerbieCost; // Requires manual initialization
Expand All @@ -2615,7 +2618,7 @@ class ApplicableOutput {
double grad, unsigned executions,
const TargetTransformInfo &TTI)
: component(&component), oldOutput(oldOutput), expr(expr), grad(grad),
executions(executions), TTI(TTI) {
executions(executions), TTI(&TTI) {
initialCompCost = getCompCost({oldOutput}, component.inputs, TTI);
findErasableInstructions();
}
Expand Down Expand Up @@ -2662,7 +2665,7 @@ class ApplicableOutput {
InstructionCost erasableCost = 0;

for (auto *I : erasableInsts) {
erasableCost += getInstructionCompCost(I, TTI);
erasableCost += getInstructionCompCost(I, *TTI);
}

return (candidates[candidateIndex].CompCost - erasableCost) * executions;
Expand Down Expand Up @@ -3094,7 +3097,8 @@ void setUnifiedAccuracyCost(
}

bool improveViaHerbie(
const std::string &inputExpr, ApplicableOutput &AO, Module *M,
const std::vector<std::string> &inputExprs,
std::vector<ApplicableOutput> &AOs, Module *M,
const TargetTransformInfo &TTI,
std::unordered_map<Value *, std::shared_ptr<FPNode>> &valueToNodeMap,
std::unordered_map<std::string, Value *> &symbolToValueMap) {
Expand All @@ -3105,6 +3109,7 @@ bool improveViaHerbie(
Program, "report",
"--seed", std::to_string(FPOptRandomSeed),
"--timeout", std::to_string(HerbieTimeout),
"--threads", std::to_string(HerbieNumThreads),
"--num-points", std::to_string(HerbieNumPoints),
"--num-iters", std::to_string(HerbieNumIters)};

Expand Down Expand Up @@ -3160,9 +3165,8 @@ bool improveViaHerbie(
BaseArgsList.push_back(BaseArgs);
}

bool InitialValuesSet = false;

std::unordered_set<std::string> seenExprs;
bool success = false;

for (const auto &BaseArgs : BaseArgsList) {
SmallString<32> tmpin, tmpout;
Expand All @@ -3187,7 +3191,9 @@ bool improveViaHerbie(
llvm::sys::fs::remove(tmpout);
continue;
}
input << inputExpr;
for (const auto &expr : inputExprs) {
input << expr << "\n";
}
input.close();

SmallVector<llvm::StringRef> Args = BaseArgs;
Expand Down Expand Up @@ -3238,71 +3244,74 @@ bool improveViaHerbie(

json::Object *obj = parsed->getAsObject();
json::Array &tests = *obj->getArray("tests");
StringRef bestExpr = tests[0].getAsObject()->getString("output").getValue();

if (bestExpr == "#f") {
continue;
}
assert(tests.size() == AOs.size() &&
"improveViaHerbie: Size mismatch between number of tests and AOs");

if (seenExprs.count(bestExpr.str()) != 0) {
continue; // Expression already seen, skip it
}
seenExprs.insert(bestExpr.str());
for (size_t i = 0; i < tests.size(); ++i) {
auto &test = *tests[i].getAsObject();

double bits = tests[0].getAsObject()->getNumber("bits").getValue();
json::Array &costAccuracy =
*tests[0].getAsObject()->getArray("cost-accuracy");
StringRef bestExpr = test.getString("output").getValue();

json::Array &initial = *costAccuracy[0].getAsArray();
double initialCostVal = initial[0].getAsNumber().getValue();
double initialCost = 1.0;
double initialAccuracy = 1.0 - initial[1].getAsNumber().getValue() / bits;
if (bestExpr == "#f") {
continue;
}

double bits = test.getNumber("bits").getValue();
json::Array &costAccuracy = *test.getArray("cost-accuracy");

json::Array &initial = *costAccuracy[0].getAsArray();
double initialCostVal = initial[0].getAsNumber().getValue();
double initialCost = 1.0;
double initialAccuracy = 1.0 - initial[1].getAsNumber().getValue() / bits;

ApplicableOutput &AO = AOs[i];

if (!InitialValuesSet) {
AO.initialHerbieCost = initialCost;
AO.initialHerbieAccuracy = initialAccuracy;
InitialValuesSet = true;
}

json::Array &best = *costAccuracy[1].getAsArray();
double bestCost = best[0].getAsNumber().getValue() / initialCostVal;
double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits;
json::Array &best = *costAccuracy[1].getAsArray();
double bestCost = best[0].getAsNumber().getValue() / initialCostVal;
double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits;

RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str());
bestCandidate.CompCost =
getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(bestCandidate);
RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str());
bestCandidate.CompCost =
getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(bestCandidate);

json::Array &alternatives = *costAccuracy[2].getAsArray();
json::Array &alternatives = *costAccuracy[2].getAsArray();

// Handle alternatives
for (size_t i = 0; i < alternatives.size(); ++i) {
json::Array &entry = *alternatives[i].getAsArray();
StringRef expr = entry[2].getAsString().getValue();
std::unordered_set<std::string> seenExprs;
seenExprs.insert(bestExpr.str());

if (seenExprs.count(expr.str()) != 0) {
continue;
// Handle alternatives
for (size_t j = 0; j < alternatives.size(); ++j) {
json::Array &entry = *alternatives[j].getAsArray();
StringRef expr = entry[2].getAsString().getValue();

if (seenExprs.count(expr.str()) != 0) {
continue;
}
seenExprs.insert(expr.str());

double cost = entry[0].getAsNumber().getValue() / initialCostVal;
double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits;

RewriteCandidate candidate(cost, accuracy, expr.str());
candidate.CompCost =
getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(candidate);
}
seenExprs.insert(expr.str());

double cost = entry[0].getAsNumber().getValue() / initialCostVal;
double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits;
setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap);

RewriteCandidate candidate(cost, accuracy, expr.str());
candidate.CompCost =
getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(candidate);
success = true;
}
}

if (AO.candidates.empty()) {
return false;
}

setUnifiedAccuracyCost(AO, valueToNodeMap, symbolToValueMap);
return true;
return success;
}

std::string getHerbieOperator(const Instruction &I) {
Expand Down Expand Up @@ -4266,6 +4275,9 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
<< *input << "\n";
}

std::vector<std::string> herbieInputs;
std::vector<ApplicableOutput> newAOs;

assert(component.outputs.size() > 0 && "No outputs found for component");
for (auto &output : component.outputs) {
// 3) run fancy opts
Expand Down Expand Up @@ -4310,16 +4322,19 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
if (EnzymePrintHerbie)
llvm::errs() << "Herbie input:\n" << herbieInput << "\n";

herbieInputs.push_back(herbieInput);
ApplicableOutput AO(component, output, expr, grad, executions, TTI);
if (!improveViaHerbie(herbieInput, AO, F.getParent(), TTI,
valueToNodeMap, symbolToValueMap)) {
if (EnzymePrintHerbie)
llvm::errs() << "Failed to optimize an expression using Herbie!\n";
continue;
}
newAOs.push_back(std::move(AO));
}

AOs.push_back(std::move(AO));
if (!improveViaHerbie(herbieInputs, newAOs, F.getParent(), TTI,
valueToNodeMap, symbolToValueMap)) {
if (EnzymePrintHerbie)
llvm::errs() << "Failed to optimize expressions using Herbie!\n";
}

AOs.insert(AOs.end(), std::make_move_iterator(newAOs.begin()),
std::make_move_iterator(newAOs.end()));
}

if (FPOptEnablePT) {
Expand Down

0 comments on commit c67f13d

Please sign in to comment.