diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index a89ae97a2b4..61364fe4435 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -3165,7 +3165,7 @@ bool improveViaHerbie( BaseArgsList.push_back(BaseArgs); } - std::unordered_set seenExprs; + std::vector> seenExprs; bool success = false; for (const auto &BaseArgs : BaseArgsList) { @@ -3270,30 +3270,31 @@ bool improveViaHerbie( AO.initialHerbieCost = initialCost; AO.initialHerbieAccuracy = initialAccuracy; - json::Array &best = *costAccuracy[1].getAsArray(); - double bestCost = best[0].getAsNumber().getValue() / initialCostVal; - double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; + if (seenExprs[i].count(bestExpr.str()) == 0) { + seenExprs[i].insert(bestExpr.str()); - RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); - bestCandidate.CompCost = - getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, - cast(AO.oldOutput)->getFastMathFlags()); - AO.candidates.push_back(bestCandidate); + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().getValue() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits; - json::Array &alternatives = *costAccuracy[2].getAsArray(); + RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); + bestCandidate.CompCost = getCompCost( + bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(AO.oldOutput)->getFastMathFlags()); + AO.candidates.push_back(bestCandidate); + } - std::unordered_set seenExprs; - seenExprs.insert(bestExpr.str()); + json::Array &alternatives = *costAccuracy[2].getAsArray(); // Handle alternatives for (size_t j = 0; j < alternatives.size(); ++j) { json::Array &entry = *alternatives[j].getAsArray(); StringRef expr = entry[2].getAsString().getValue(); - if (seenExprs.count(expr.str()) != 0) { + if (seenExprs[i].count(expr.str()) != 0) { continue; } - seenExprs.insert(expr.str()); + seenExprs[i].insert(expr.str()); double cost = entry[0].getAsNumber().getValue() / initialCostVal; double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits;