Skip to content

Commit

Permalink
[BugFix] Remove duplicated definition of MakeMultinomialSampler (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Dec 7, 2021
1 parent 1249af5 commit d91b43f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 25 deletions.
7 changes: 4 additions & 3 deletions src/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ inline std::vector<double> PredictNormalizedScore(const std::vector<CachedTrace>

/**************** Evolutionary Search ****************/

// TODO(@zxybazh): Early stopping for small search space, including deduplication.
/*!
* \brief A search strategy that generates measure candidates using evolutionary search.
* \note The algorithm:
Expand All @@ -268,7 +269,6 @@ inline std::vector<double> PredictNormalizedScore(const std::vector<CachedTrace>
* pick `k = num_measures_per_iter * eps_greedy ` from `init`
* do the measurement on `chosen` & update the cost model
*
* Todo: (@zxybazh): Early stopping for small search space, including deduplication.
*/
class EvolutionarySearchNode : public SearchStrategyNode {
public:
Expand Down Expand Up @@ -489,7 +489,7 @@ inline std::vector<CachedTrace> EvolutionarySearchNode::State::PickBestFromDatab
inline std::vector<CachedTrace> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
// Pick unmeasured states
std::vector<CachedTrace> results(num);
auto f_proc_unmeasured = [this, &results, &num](int thread_id, int trace_id) -> void {
auto f_proc_unmeasured = [this, &results](int thread_id, int trace_id) -> void {
TRandState& rand_state = self->per_thread_data_[thread_id].rand_state;
const IRModule& mod = self->per_thread_data_[thread_id].mod;
CachedTrace& result = results[trace_id];
Expand Down Expand Up @@ -574,7 +574,8 @@ std::vector<CachedTrace> EvolutionarySearchNode::State::EvolveWithCostModel(
if (Optional<Mutator> opt_mutator = mutator_sampler()) {
// Decision: mutate
Mutator mutator = opt_mutator.value();
if (Optional<tir::Trace> opt_new_trace = mutator->Apply(ctrace.sch->trace().value())) {
if (Optional<tir::Trace> opt_new_trace =
mutator->Apply(ctrace.sch->trace().value(), &rand_state)) {
tir::Trace new_trace = opt_new_trace.value();
if (Optional<tir::Schedule> opt_sch =
ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) {
Expand Down
22 changes: 0 additions & 22 deletions src/tir/schedule/primitive/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,28 +125,6 @@ struct PrimeTable {
}
};

std::function<int32_t()> MakeMultinomialSampler(
support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights) {
std::vector<double> sums;
sums.reserve(weights.size());
double sum = 0.0;
for (double w : weights) {
sums.push_back(sum += w);
}
std::uniform_real_distribution<double> dist(0.0, sum);
auto sampler = [rand_state = support::LinearCongruentialEngine(rand_state).ForkSeed(),
dist = std::move(dist), sums = std::move(sums)]() mutable -> int32_t {
support::LinearCongruentialEngine rand_(&rand_state);
double p = dist(rand_);
int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin();
int32_t n = sums.size();
CHECK_LE(0, idx);
CHECK_LE(idx, n);
return (idx == n) ? (n - 1) : idx;
};
return sampler;
}

int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive,
int32_t max_exclusive) {
CHECK(min_inclusive < max_exclusive)
Expand Down

0 comments on commit d91b43f

Please sign in to comment.