From 4eb75cef31ffeb5fc663e4ca6326cd91e18f6854 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 29 Nov 2021 12:19:11 -0800 Subject: [PATCH 01/14] Checkpoint. Fix cost model comment. Finish evolutionary seaarch. Remove extra code. Fix compile. Add comments. Add python part. Ad test. Update other files & comments. --- include/tvm/meta_schedule/cost_model.h | 4 +- include/tvm/meta_schedule/search_strategy.h | 14 + .../meta_schedule/cost_model/cost_model.py | 3 +- .../meta_schedule/search_strategy/__init__.py | 1 + .../search_strategy/evolutionary_search.py | 101 +++ .../search_strategy/evolutionary_search.cc | 601 ++++++++++++++++++ src/meta_schedule/utils.h | 131 +++- src/target/target_kind.cc | 1 + src/tir/schedule/primitive.h | 10 + ...schedule_rule_parallel_vectorize_unroll.py | 1 + .../test_meta_schedule_search_strategy.py | 163 ++++- 11 files changed, 1013 insertions(+), 17 deletions(-) create mode 100644 python/tvm/meta_schedule/search_strategy/evolutionary_search.py create mode 100644 src/meta_schedule/search_strategy/evolutionary_search.cc diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index adc1955a1e..b05dc3c118 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -59,10 +59,10 @@ class CostModelNode : public runtime::Object { const Array& results) = 0; /*! - * \brief Predict the running results of given measure candidates. + * \brief Predict the normalized score (the larger the better) of given measure candidates. * \param tune_context The tuning context. * \param candidates The measure candidates. - * \return The predicted running results. + * \return The predicted normalized score. 
*/ virtual std::vector Predict(const TuneContext& tune_context, const Array& candidates) = 0; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 3a0fa0ab4a..f2798bf8f9 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -20,6 +20,8 @@ #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ #include +#include +#include #include #include #include @@ -29,6 +31,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class CostModel; /*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { @@ -255,6 +258,17 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + double init_measured_ratio, // + int genetic_algo_iters, // + double p_mutate, // + double eps_greedy, // + Map mutator_probs, // + Database database, // + CostModel cost_model); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index 0cbba42a31..13ca203c90 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -87,7 +87,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) Return ------ result : np.ndarray - The predicted running results. + The predicted normalized score. 
""" n = len(candidates) results = np.zeros(shape=(n,), dtype="float64") @@ -117,7 +117,6 @@ def f_save(path: str) -> None: @check_override(self.__class__, CostModel) def f_update( - self, tune_context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index f385b72db4..6102ebc41a 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -23,3 +23,4 @@ from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace, ReplayTraceConfig from .replay_func import ReplayFunc, ReplayFuncConfig +from .evolutionary_search import EvolutionarySearch diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 0000000000..0b5538ed40 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Evolutionary Search Strategy""" + +from typing import TYPE_CHECKING, Dict + +from tvm._ffi import register_object +from ...tir import FloatImm + +from .search_strategy import SearchStrategy +from ..mutator import Mutator +from ..database import Database + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..cost_model import CostModel + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Evolutionary Search Strategy is a search strategy that evolves an initial population of traces + with mutators, guided by a cost model, and picks the candidates to measure epsilon-greedily. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + population : int + The initial population of traces from measured samples and randomly generated samples. + init_measured_ratio : float + The ratio of measured samples in the initial population. + genetic_algo_iters : int + The number of iterations for genetic algorithm. + p_mutate : float + The probability of mutation. + eps_greedy : float + The ratio of samples picked at random from the initial population in the final picks. + mutator_probs: Dict[Mutator, FloatImm] + The probability contribution of all mutators. + database : Database + The database used in the search. + cost_model : CostModel + The cost model used in the search. 
+ """ + + num_trials_per_iter: int + num_trials_total: int + population: int + init_measured_ratio: float + genetic_algo_iters: int + p_mutate: float + eps_greedy: float + mutator_probs: Dict[Mutator, FloatImm] + database: Database + cost_model: "CostModel" + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + population: int, + init_measured_ratio: float, + genetic_algo_iters: int, + p_mutate: float, + eps_greedy: float, + mutator_probs: Dict[Mutator, FloatImm], + database: Database, + cost_model: "CostModel", + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population, + init_measured_ratio, + genetic_algo_iters, + p_mutate, + eps_greedy, + mutator_probs, + database, + cost_model, + ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 0000000000..fd65ab9586 --- /dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/ + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/**************** Data Structure ****************/ + +/*! + * \brief A heap with a size up-limit. If overflow happens, it evicts the worst item. + * \note It maintains a min heap in terms of `CachedTrace::score`. Therefore, when + * overflow happens, the element evicted is the one with the min `CachedTrace::score`. + * As more items are pushed, the scores kept in the heap can only get larger. + */ +class SizedHeap { + /*! \brief The comparator class, used by `std::push_heap` and `std::pop_heap` */ + struct Comparator { + bool operator()(const CachedTrace& a, const CachedTrace& b) const { return a.score > b.score; } + }; + + public: + /*! + * \brief Constructor + * \param size_limit The up-limit of the heap size + */ + explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } + + /*! + * \brief Push the item to the heap unless an item with the same `repr` was pushed before + * \param item The item to be pushed + */ + void Push(const CachedTrace& item) { + if (!in_heap.insert(item.repr).second) { + return; + } + int size = heap.size(); + if (size < size_limit) { + // Heap is not full, just push + heap.emplace_back(item); + std::push_heap(heap.begin(), heap.end(), Comparator()); + } else if (Comparator()(item, heap.front())) { + // The item beats the worst one in the heap, so evict the worst and insert the item + std::pop_heap(heap.begin(), heap.end(), Comparator()); + heap.back() = item; + std::push_heap(heap.begin(), heap.end(), Comparator()); + } + // Otherwise, the item is no better than any element in the heap and is dropped + } + + /*! \brief Up-limit of the heap size */ + int size_limit; + /*! \brief The heap; the item with the lowest score sits at the front */ + std::vector heap; + /*! \brief The `repr` of every item ever pushed (kept after eviction), for de-duplication */ + std::unordered_set in_heap; +}; + +/*! + * \brief A search strategy that generates measure candidates using evolutionary search. 
+ * \note The algorithm: + * + * Loop until #measured >= total_measures: + * init = + * pick top `k = population * init_measured_ratio ` from measured + * pick `k = population * (1 - init_measured_ratio)` random selected from search space + * best = generate `population` states with the cost model, + * starting from `init`, + * using mutators, + * and return the top-n states during the search, + * where `n = num_measures_per_iter` + * chosen = pick top `k = num_measures_per_iter * (1 - eps_greedy)` from `best` + * pick `k = num_measures_per_iter * eps_greedy ` from `init` + * do the measurement on `chosen` & update the cost model + */ +class EvolutionarySearchNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + EvolutionarySearchNode* self; + /*! \brief The design spaces. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(EvolutionarySearchNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + /*! + * \brief Sample the initial population from previous measured results and randomly generated + * traces via trace replaying. + * \return The initial population of traces sampled. + */ + inline Array SampleInitPopulation(); + /*! + * \brief Evolve the initial population using mutators and samplers. + * \param inits The initial population of traces sampled. + * \return The evolved traces from initial population. + */ + inline Array EvolveWithCostModel(const Array& inits); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param inits The initial population of traces sampled. 
+ * \param bests The best candidates predicted from evolved traces. + * \return The final picked candidates with a ratio of both. + */ + inline Array PickWithEpsGreedy(const Array& inits, + const Array& bests); + /*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ + inline Array AssembleCandidates(const Array& picks); + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + /*! \brief THe population size in the evolutionary search.*/ + int population; + + /*! \brief The target for the workload. */ + Target target_{nullptr}; + /*! \brief The tuning context of the evolutionary search strategy. */ + TuneContext tune_context_{nullptr}; + /*! \brief The mutators to be used. */ + Array mutators_{nullptr}; + /*! \brief The module to be tuned. */ + Array mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The number of threads to use. -1 means using logical cpu number. */ + int num_threads_ = -1; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + /*** Configuration: the initial population ***/ + /*! \brief The ratio of measured states used in the initial population */ + double init_measured_ratio; + + /*** Configuration: evolution ***/ + /*! \brief The number of iterations performed by generic algorithm. */ + int genetic_algo_iters; + /*! \brief The probability to perform mutation */ + double p_mutate; + /*! \brief Mutators and their probability mass */ + Map mutator_probs{nullptr}; + /*! \brief A Database for selecting useful candidates. 
*/ + Database database{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model{nullptr}; + /*! \brief The batch of measure candidates generated for measurement. */ + Array candidates{nullptr}; + + /*** Configuration: pick states for measurement ***/ + /*! \brief The ratio of measurements to use randomly sampled states. */ + double eps_greedy; + + /*! + * Helpers + * Note that the use of trace cache could be multi-threaded. + */ + mutable std::unordered_map trace_cache_; + mutable std::mutex trace_cache_mutex_; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `tune_context_` is not visited + // `target_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("population", &population); + /*** Configuration: the initial population ***/ + v->Visit("init_measured_ratio", &init_measured_ratio); + /*** Configuration: evolution ***/ + v->Visit("genetic_algo_iters", &genetic_algo_iters); + v->Visit("p_mutate", &p_mutate); + v->Visit("mutator_probs", &mutator_probs); + v->Visit("cost_model", &cost_model); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + /*** Helpers ***/ + // Not visited: `trace_cache_` + // Not visited: `trace_cache_mutex_` + } + + /*! + * \brief Add the cached trace into the trace_cache_ + * \param cached_trace The cached_trace to be added + */ + void _AddCachedTrace(const CachedTrace& cached_trace) const { + // Todo(@zxybazh): Avoid redundent traces + std::unique_lock lock(this->trace_cache_mutex_); + trace_cache_.emplace(GetRef(cached_trace.trace), cached_trace); + } + + /*! 
+ * \brief Retrieve the cached trace given the trace + * \param trace The trace to be retrieved + * \return The cached trace + */ + CachedTrace _GetCachedTrace(const tir::Trace& trace) const { + auto iter = trace_cache_.find(trace); + ICHECK(iter != trace_cache_.end()); + return iter->second; + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + CHECK(tune_context.defined()) << "TuneContext must be defined!"; + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(tune_context->mutators.defined()) << "Mutators must be defined!"; + CHECK(tune_context->target.defined()) << "Target must be defined!"; + this->tune_context_ = tune_context; + this->target_ = tune_context->target.value(); + this->mutators_ = tune_context->mutators.value(); + this->num_threads_ = tune_context->num_threads; + + this->mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this, design_spaces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Array EvolutionarySearchNode::State::SampleInitPopulation() { + 
self->trace_cache_.clear(); + std::vector results; + results.reserve(self->population); + // Threading RNG. It and `mod_` hold one entry per worker thread: index them by `thread_id` + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + // Pick measured states + int num_measured = self->population * self->init_measured_ratio; + for (TuningRecord record : + self->database->GetTopK(self->database->CommitWorkload(self->mod_[0]), num_measured)) { + results.push_back(record->trace); + } + + auto f_proc_measured = [this, &results, &per_thread_rand_state](int thread_id, + int trace_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + const tir::Trace& trace = results[trace_id]; + if (Optional opt_sch = + meta_schedule::ReplayTrace(trace, self->mod_[thread_id], &rand_state)) { + tir::Schedule sch = opt_sch.value(); + self->_AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; + } + }; + support::parallel_for_dynamic(0, results.size(), self->num_threads_, f_proc_measured); + + // Pick unmeasured states + std::atomic tot_fail_ct(0); + std::atomic success_ct(0); + auto f_proc_unmeasured = [this, &results, &per_thread_rand_state, &tot_fail_ct, &success_ct]( + int thread_id, int trace_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + for (;;) { + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]->trace().value(); + Map decisions; + try { + if (Optional opt_sch = + meta_schedule::ReplayTrace(trace, self->mod_[thread_id], &rand_state)) { + tir::Schedule sch = opt_sch.value(); + tir::Trace old_trace = sch->trace().value(); + tir::Trace trace(old_trace->insts, old_trace->decisions); + self->_AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); + results[trace_id] = std::move(trace); + success_ct++; + break; + } else { + tot_fail_ct++; + } + } catch (const dmlc::Error& e) { + 
tot_fail_ct++; + } + if (success_ct > 64) { // Todo(@junru): Why 64? Add to constructor. + break; + } + } + }; + num_measured = results.size(); + results.resize(self->population, tir::Trace(nullptr)); + support::parallel_for_dynamic(num_measured, self->population, self->num_threads_, + f_proc_unmeasured); + std::vector pruned_results; + for (const tir::Trace& result : results) { + if (result.defined()) { + pruned_results.push_back(result); + } + } + // LOG(INFO) << "fail count: " << tot_fail_ct; + return pruned_results; +} + +Array EvolutionarySearchNode::State::EvolveWithCostModel( + const Array& inits) { + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(self->num_trials_per_iter); + // Threading RNG + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + std::vector> thread_trace_samplers(self->num_threads_); + std::vector()>> thread_mutator_samplers(self->num_threads_); + std::vector trace_used; + std::mutex trace_used_mutex; + // Prepare search queues + std::vector sch_curr; + std::vector sch_next; + sch_curr.reserve(self->population); + sch_next.reserve(self->population); + for (const tir::Trace& trace : inits) { + sch_curr.push_back(self->_GetCachedTrace(trace)); + } + // Main loop: (genetic_algo_iters + 1) times + for (int iter = 0;; ++iter) { + // Predict running time with the cost model, + // and put the schedules with the predicted perf to the heap + std::vector scores = + PredictNormalizedScore(sch_curr, self->tune_context_, self->cost_model, self->args_info_); + for (int i = 0, n = sch_curr.size(); i < n; ++i) { + CachedTrace& entry = sch_curr[i]; + entry.score = scores[i]; + if (!self->database->GetTopK(self->database->CommitWorkload(entry.sch->mod()), 1).size()) { + heap.Push(entry); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_algo_iters) { + 
break; + } + // Set threaded samplers, with probability from predicated normalized throughputs + for (int i = 0; i < self->num_threads_; ++i) { + TRandState& rand_state = per_thread_rand_state[i]; + thread_trace_samplers[i] = MakeMultinomial(rand_state, scores); + thread_mutator_samplers[i] = + MakeMutatorSampler(self->p_mutate, self->mutator_probs, rand_state); + } + trace_used = std::vector(scores.size(), 0); + // The worker function + auto f_find_candidate = [&per_thread_rand_state, &thread_trace_samplers, + &thread_mutator_samplers, &trace_used, &trace_used_mutex, &sch_curr, + &sch_next, this](int thread_id, int i) { + // Prepare samplers. Per-thread resources (RNG, samplers, `mod_`) are indexed by `thread_id` + TRandState& rand_state = per_thread_rand_state[thread_id]; + const std::function& trace_sampler = thread_trace_samplers[thread_id]; + const std::function()>& mutator_sampler = + thread_mutator_samplers[thread_id]; + // Loop until success + int max_retry_cnt = 10; + int retry_cnt = 0; + for (;;) { + int trace_idx = trace_sampler(); + const CachedTrace& cached_trace = sch_curr[trace_idx]; + if (Optional opt_mutator = mutator_sampler()) { + // Decision: mutate + Mutator mutator = opt_mutator.value(); + if (Optional opt_new_trace = + mutator->Apply(GetRef(cached_trace.trace))) { + tir::Trace new_trace = opt_new_trace.value(); + if (Optional opt_sch = + ReplayTrace(new_trace, self->mod_[thread_id], &rand_state)) { + tir::Schedule sch = opt_sch.value(); + CachedTrace new_cached_trace{new_trace.get(), sch, Repr(sch), -1.0}; + self->_AddCachedTrace(new_cached_trace); + sch_next[i] = new_cached_trace; + break; + } + } + } else { + // Decision: do not mutate + std::unique_lock lock(trace_used_mutex); + if (!trace_used[trace_idx]) { + trace_used[trace_idx] = 1; + sch_next[i] = cached_trace; + break; + } + } + retry_cnt++; + if (retry_cnt >= max_retry_cnt) { + sch_next[i] = cached_trace; + break; + } + } + }; + sch_next.clear(); + sch_next.resize(self->population); + support::parallel_for_dynamic(0, self->population, 1, f_find_candidate); + 
sch_curr.clear(); + sch_curr.swap(sch_next); + } + // Return the best states from the heap, sorting from higher score to lower ones + std::sort(heap.heap.begin(), heap.heap.end(), CachedTrace::Compare); + Array results; + results.reserve(self->num_trials_per_iter); + for (const CachedTrace& item : heap.heap) { + results.push_back(GetRef(item.trace)); + } + /* Logging + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; + } + os << std::fixed << std::setprecision(4) << heap.heap[i].score; + } + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + */ + return results; +} + +Array EvolutionarySearchNode::State::PickWithEpsGreedy(const Array& inits, + const Array& bests) { + int num_rands = self->num_trials_per_iter * self->eps_greedy; + int num_bests = self->num_trials_per_iter - num_rands; + std::vector rands = + tir::SampleWithoutReplacement(&self->rand_state_, inits.size(), inits.size()); + Array results; + results.reserve(self->num_trials_per_iter); + for (int i = 0, i_bests = 0, i_rands = 0; i < self->num_trials_per_iter; ++i) { + bool has_best = i_bests < static_cast(bests.size()); + bool has_rand = i_rands < static_cast(rands.size()); + // Pick a schedule + Optional trace{NullOpt}; + // If needs `bests`, then prefer `bests` + if (i < num_bests) { + if (has_best) { + trace = bests[i_bests++]; + } else if (has_rand) { + trace = inits[rands[i_rands++]]; + } else { + break; + } + } else { + // Else prefer `rands` + if (has_rand) { + trace = inits[rands[i_rands++]]; + } else if (has_best) { + trace = bests[i_bests++]; + } else { + break; + } + } + results.push_back(trace.value()); + } + return results; +} + +inline Array EvolutionarySearchNode::State::AssembleCandidates( + 
const Array& picks) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const tir::Trace& pick : picks) { + CachedTrace trace = self->_GetCachedTrace(pick); + measure_inputs.push_back(MeasureCandidate(trace.sch, self->args_info_)); + } + return measure_inputs; +} + +inline Optional> +EvolutionarySearchNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + self->candidates = Array(nullptr); + return NullOpt; + } + if (ed > self->num_trials_total) { + self->num_trials_per_iter += self->num_trials_total - ed; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + + // new parts + Array inits = SampleInitPopulation(); + Array bests = EvolveWithCostModel(inits); + Array picks = PickWithEpsGreedy(inits, bests); + self->candidates = AssembleCandidates(picks); + return self->candidates; +} + +inline void EvolutionarySearchNode::State::NotifyRunnerResults(const Array& results) { + // We need to assume the candidates' order are not changed in runner. 
+ ICHECK(self->candidates.defined() && self->candidates.size() == results.size()); + st += results.size(); + ed += results.size(); + for (int i = 0, n = results.size(); i < n; ++i) { + RunnerResult result = results[i]; + // Todo: Update to database measure callback + // Skip failed runs; `i` still advances so candidates[i] stays paired with results[i]. + if (result->error_msg.defined() || !result->run_secs.defined()) continue; + self->database->CommitTuningRecord(TuningRecord( + /*trace=*/self->candidates[i]->sch->trace().value(), // + /*run_secs=*/result->run_secs.value(), // + /*workload=*/self->database->CommitWorkload(self->mod_[0]), // + /*target=*/self->target_, // + /*args_info=*/self->candidates[i]->args_info)); + } + // Todo: Update to cost model measure callback (called once per batch, not once per result) + self->cost_model->Update(self->tune_context_, self->candidates, results); +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + double init_measured_ratio, // + int genetic_algo_iters, // + double p_mutate, // + double eps_greedy, // + Map mutator_probs, // + Database database, // + CostModel cost_model) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population = population; + n->init_measured_ratio = init_measured_ratio; + n->genetic_algo_iters = genetic_algo_iters; + n->p_mutate = p_mutate; + n->eps_greedy = eps_greedy; + n->mutator_probs = mutator_probs; + n->database = database; + n->cost_model = cost_model; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index b11d29f63f..12982730a4 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -37,6 +37,7 @@ #include #include +#include #include #include @@ -241,7 +242,7 @@ inline tir::BlockRV 
GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& /*! * \brief Get the number of cores in CPU * \param target The target - * \param + * \return The number of cores. */ inline int GetTargetNumCores(const Target& target) { int num_cores = target->GetAttr("num-cores").value_or(-1); @@ -250,11 +251,8 @@ inline int GetTargetNumCores(const Target& target) { ICHECK(f_cpu_count) << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; num_cores = (*f_cpu_count)(false); - LOG(WARNING) << "Warning: Target does not have attribute \"num-cores\", falling back the " - "number of CPU cores on the local machine. The inaccuracy in number of " - "cores may lead to dramatically inferior performance. Falling back to " - "assuming " - << num_cores << " CPU core(s)"; + LOG(FATAL) << "Target does not have attribute \"num_cores\", pyhsical core number must be " + "defined! Example: Local target ...."; } return num_cores; } @@ -285,6 +283,127 @@ inline Optional ApplyTrace(const IRModule& mod, const tir::Trace& return sch; } +/*! + * \brief Get the string representation for the given schedule's IRModule. + * \param sch The given schedule. + * \return The string representation created. + */ +inline String Repr(const tir::Schedule& sch) { return tir::AsTVMScript(sch->mod()); } + +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. 
+ */ +inline std::function MakeMultinomial( + support::LinearCongruentialEngine::TRandState& rand_state, const std::vector& weights) { + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { + support::LinearCongruentialEngine rand_(&rand_state); + double p = dist(rand_); + int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; + return sampler; +} + +/*! + * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ +inline std::function()> MakeMutatorSampler( + double p_mutate, const Map& mutator_probs, + support::LinearCongruentialEngine::TRandState& rand_state) { + CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // + << "ValueError: Probability should be within [0, 1], " + << "but get `p_mutate = " << p_mutate << '\''; + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - p_mutate); + double total_mass_mutator = 0.0; + for (const auto& kv : mutator_probs) { + const Mutator& mutator = kv.first; + double mass = kv.second->value; + CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator + << "' is ill-formed: " << mass; + total_mass_mutator += mass; + mutators.push_back(kv.first); + masses.push_back(mass * p_mutate); + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + auto idx_sampler = 
MakeMultinomial(rand_state, masses); + return [idx_sampler = std::move(idx_sampler), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; +} + +/*! \brief The postprocessed built of a trace */ +struct CachedTrace { + /*! \brief The trace */ + const tir::TraceNode* trace; + /*! \brief The schedule the trace creates */ + tir::Schedule sch; + /*! \brief The string representation of the schedule */ + String repr; + // Todo: Challenges in deduplication: remove unit loop / simplify pass + /*! \brief The normalized score, the higher the better */ + double score; + + static bool Compare(const CachedTrace& lhs, const CachedTrace& rhs) { + return lhs.score > rhs.score; + } +}; + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +inline std::vector PredictNormalizedScore(const std::vector& cached_traces, + const TuneContext& tune_context, + const CostModel& cost_model, + Array args_info) { + Array measure_inputs; + measure_inputs.reserve(cached_traces.size()); + for (const CachedTrace& cached_trace : cached_traces) { + measure_inputs.push_back(MeasureCandidate(cached_trace.sch, args_info)); + } + + std::vector scores = cost_model->Predict(tune_context, measure_inputs); + // Normalize the score + // TODO(@junrushao1994): use softmax + temperature to replace simple normalization to [0.0, +oo) + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 8c17934358..f1dfed1644 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -227,6 +227,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") + 
.add_attr_option("num_cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index cb09d96614..0faad06492 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -30,12 +30,22 @@ namespace tir { /******** Schedule: Sampling ********/ /*! * \brief Sample a random integer from a given range. + * \param rand_state The pointer to schedule's random state * \param min_inclusive The minimum value of the range, inclusive. * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e., no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. 
* \param rand_state The pointer to schedule's random state diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index e57799f604..be041d543e 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -79,6 +79,7 @@ def _create_context(mod, target, rule): return ctx +# @pytest.mark.skip(reason="failing in staging branch @bohan") def test_parallel_vectorize_unroll(): expected = [ [ diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index f940d11b79..a21481d5dd 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -16,17 +16,28 @@ # under the License. """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring -from typing import List +from typing import List, Tuple, Union, Optional import sys - +import numpy as np import pytest import tvm + +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator.mutator import PyMutator from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace, ReplayFunc +from tvm.meta_schedule.search_strategy import ( + SearchStrategy, + ReplayTrace, + ReplayFunc, + EvolutionarySearch, + MeasureCandidate, +) +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord +from tvm.meta_schedule.cost_model import PyCostModel from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -34,7 +45,7 @@ MATMUL_M = 32 -# pylint: 
disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking # fmt: off @tvm.script.ir_module @@ -53,7 +64,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool: @@ -77,7 +88,7 @@ def _schedule_matmul(sch: Schedule): @pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) -def test_meta_schedule_replay_func(TestClass: SearchStrategy): +def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 num_trials_total = 20 @@ -98,7 +109,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): _is_trace_equal( candidate.sch, correct_sch, - remove_decisions=(type(strategy) == ReplayTrace), + remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) strategy.notify_runner_results(runner_results) @@ -107,5 +118,143 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): assert num_trials_each_iter == [7, 7, 6] +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = 
[] + self.workload_reg = [] + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, file_location: str) -> None: + self.random_state = tuple(np.load(file_location, allow_pickle=True)) + + def save(self, file_location: str) -> None: + np.save(file_location, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + mutator = DummyMutator() + database = DummyDatabase() + cost_model 
= RandomModel() + strategy = EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population=5, + init_measured_ratio=0.1, + genetic_algo_iters=3, + p_mutate=0.5, + eps_greedy=0.9, + mutator_probs={mutator: 1.0}, + database=database, + cost_model=cost_model, + ) + tune_context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + mutators=[mutator], + target=tvm.target.Target("llvm"), + ) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + print(num_trials_each_iter) + correct_count = 6 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + [ + num_trials_total % correct_count + ] + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 006244609f09ee52b2ca763371c8afef17057a0a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 5 Nov 2021 12:24:09 -0700 Subject: [PATCH 02/14] Squashed commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (#485) [Meta 
Schedule][M3c] PostOrderApply (#486) Fix Post Order Apply (#490) [MetaSchedule] Relay Integration (#489) [M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (#492) Fix replay trace. (#493) [M3c][Meta Schedule] Implement the Replay Func class. (#495) [PR] Test script for meta-schedule task extraction. Interface to load… (#494) [Meta Schedule Refactor] Get child blocks (#500) Read-at && Write-at (#497) [M3c][Meta Schedule] Measure Callbacks (#498) [Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (#496) [MetaSchedule] Sample-Perfect-Tile (#501) [MetaSchedule] TE Workloads (#502) Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Wuwei Lin Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> --- src/tir/schedule/primitive/sampling.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6061ca527e..f041d3a1f3 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,6 +20,7 @@ #include #include "../utils.h" +#include "tvm/support/random_engine.h" namespace tvm { namespace tir { From 69e38c887495c25288f3afdf7df48d736b240662 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 5 Nov 2021 23:40:03 -0700 Subject: [PATCH 03/14] [TensorIR] GetProducer, GetConsumer (#506) --- python/tvm/tir/schedule/schedule.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c5871a53eb..9ad16f5b4a 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -471,6 +471,36 @@ def get_consumers(self, block: BlockRV) -> List[BlockRV]: """ return _ffi_api.ScheduleGetConsumers(self, block) # type: 
ignore # pylint: disable=no-member + def get_producers(self, block: BlockRV) -> List[BlockRV]: + """Get the producers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + producers : List[BlockRV] + A list of producers of the given block + """ + return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member + + def get_consumers(self, block: BlockRV) -> List[BlockRV]: + """Get the consumers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + consumers : List[BlockRV] + A list of consumers of the given block + """ + return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member + ########## Schedule: Transform loops ########## def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. It requires: From e0c9e6652eb8cd2fc76e452ae92a4308d1085f3c Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Sat, 6 Nov 2021 22:08:08 -0400 Subject: [PATCH 04/14] [MetaScheduleRefactor] Annotate&Unannotate (#505) * annotate * annotate * lint * test * fix * fix * fix --- python/tvm/tir/schedule/schedule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 9ad16f5b4a..aba59acbcb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,6 +16,7 @@ # under the License. 
"""The TensorIR schedule class""" from typing import Dict, List, Optional, Union +from typing_extensions import Annotated from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error From f6536213422cec6fcbb833eff7acb9a4b36f503a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 10 Nov 2021 23:44:42 -0800 Subject: [PATCH 05/14] [MetaSchedule] Rewrite Cooperative-Fetching / Unbound-Block / Reduction-Block (#509) --- src/tir/schedule/primitive/sampling.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index f041d3a1f3..6061ca527e 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,7 +20,6 @@ #include #include "../utils.h" -#include "tvm/support/random_engine.h" namespace tvm { namespace tir { From 8d829c21b70ebda8db807fa3d8330077dc1fbd87 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 18 Nov 2021 18:52:56 -0500 Subject: [PATCH 06/14] Blockize & Tensorize (#514) * Blockize & Tensorize * Update tensor intrin * Fix blockized & Recalculate affine flags * Cleanup utils.cc * Add test cases of blockize * Re-enable affine flag checking --- include/tvm/tir/function.h | 30 +++++++++++++++++++++++++++ src/tir/schedule/analysis/analysis.cc | 12 +++++++++++ 2 files changed, 42 insertions(+) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 3efc419e4d..10a0de3042 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -263,6 +263,36 @@ class TensorIntrin : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) }; +/*! + * \brief Tensor TensorIntrin for Tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc description; + /*! \brief The intrinsic function for lower-level implement. 
*/ + PrimFunc implementation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("description", &description); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +class TensorIntrin : public ObjectRef { + public: + TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 5ed88da65d..e571896cd6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1672,5 +1672,17 @@ bool HasIfThenElse(const Stmt& stmt) { return has_branch; } +bool CheckOneLine(const Stmt& s) { + bool legal = true, meet_block = false; + PostOrderVisit(s, [&legal, &meet_block](const ObjectRef& obj) { + if (obj->IsInstance() && !meet_block) { + legal = false; + } else if (obj->IsInstance()) { + meet_block = true; + } + }); + return legal; +} + } // namespace tir } // namespace tvm From e171f01c592358a2465c0c5757610d1aab1433a9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 22 Nov 2021 23:45:14 -0800 Subject: [PATCH 07/14] Checkpoint. Fix cost model comment. Finish evolutionary seaarch. Remove extra code. Fix compile. Add comments. Add python part. Ad test. Update other files & comments. Fix random seed bug. Minor fix. Fix num-cores. Add docs. Check point. Add max_fail_cnt. Minor fix. Minor fix. Segfault. Fix pointers to trace. Test fix. Remove measure callbacks. Refactor a bit. Split function. Adjust variable name. Minor fixes. Add mutator probs to TuneContext. Add token. Fix loops. Remove include. 
Add has workload for database. Add check. Add concurrent bitmask. --- include/tvm/meta_schedule/database.h | 24 +- include/tvm/meta_schedule/search_strategy.h | 37 +- include/tvm/meta_schedule/tune_context.h | 10 +- include/tvm/support/random_engine.h | 1 + include/tvm/tir/function.h | 30 - .../tvm/meta_schedule/cost_model/xgb_model.py | 2 +- python/tvm/meta_schedule/database/database.py | 20 + .../search_strategy/evolutionary_search.py | 29 +- python/tvm/meta_schedule/tune_context.py | 39 +- python/tvm/tir/schedule/schedule.py | 31 - src/meta_schedule/database/database.cc | 6 +- src/meta_schedule/database/json_database.cc | 4 + .../search_strategy/evolutionary_search.cc | 700 +++++++++++------- src/meta_schedule/tune_context.cc | 14 +- src/meta_schedule/utils.h | 128 +--- src/target/target_kind.cc | 3 +- src/tir/schedule/analysis/analysis.cc | 12 - src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/primitive.h | 10 + src/tir/schedule/primitive/sampling.cc | 23 + src/tir/schedule/traced_schedule.cc | 2 +- ...schedule_rule_parallel_vectorize_unroll.py | 1 - .../test_meta_schedule_search_strategy.py | 26 +- .../test_meta_schedule_task_scheduler.py | 6 + 24 files changed, 598 insertions(+), 562 deletions(-) diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 60c6898f00..307ec309c0 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -155,6 +155,12 @@ class DatabaseNode : public runtime::Object { public: /*! \brief Default destructor */ virtual ~DatabaseNode() = default; + /*! + * \brief Check if the database has the given workload. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + virtual bool HasWorkload(const IRModule& mod) = 0; /*! * \brief Look up or add workload to the database if missing. * \param mod The IRModule to be searched for or added. 
@@ -186,6 +192,12 @@ class DatabaseNode : public runtime::Object { /*! \brief The database with customized methods on the python-side. */ class PyDatabaseNode : public DatabaseNode { public: + /*! + * \brief The function type of `HasWorkload` method. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + using FHasWorkload = runtime::TypedPackedFunc; /*! * \brief The function type of `CommitWorkload` method. * \param mod The IRModule to be searched for or added. @@ -210,6 +222,8 @@ class PyDatabaseNode : public DatabaseNode { */ using FSize = runtime::TypedPackedFunc; + /*! \brief The packed function to the `HasWorkload` function. */ + FHasWorkload f_has_workload; /*! \brief The packed function to the `CommitWorkload` function. */ FCommitWorkload f_commit_workload; /*! \brief The packed function to the `CommitTuningRecord` function. */ @@ -224,12 +238,18 @@ class PyDatabaseNode : public DatabaseNode { // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. // + // `f_has_workload` is not visited // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited // `f_get_top_k` is not visited // `f_size` is not visited } + bool HasWorkload(const IRModule& mod) final { + ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; + return f_has_workload(mod); + } + Workload CommitWorkload(const IRModule& mod) final { ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; return f_commit_workload(mod); @@ -271,13 +291,15 @@ class Database : public runtime::ObjectRef { bool allow_missing); /*! * \brief Create a database with customized methods on the python-side. + * \param f_has_workload The packed function of `HasWorkload`. * \param f_commit_workload The packed function of `CommitWorkload`. 
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`. * \param f_get_top_k The packed function of `GetTopK`. * \param f_size The packed function of `Size`. * \return The created database. */ - TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, + PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size); diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index f2798bf8f9..e645f15ef1 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -20,10 +20,7 @@ #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ #include -#include -#include #include -#include #include namespace tvm { @@ -32,6 +29,7 @@ namespace meta_schedule { // Forward declaration class TuneContext; class CostModel; +class Database; /*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { @@ -258,15 +256,30 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); - TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // - int num_trials_total, // - int population, // - double init_measured_ratio, // - int genetic_algo_iters, // - double p_mutate, // - double eps_greedy, // - Map mutator_probs, // - Database database, // + /*! + * \brief Constructor of evolutionary search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for evolutionary search. + * \param population The initial sample population. + * \param max_replay_fail_cnt The maximum number to fail trace replaying. 
+ * \param init_measured_ratio The ratio of measured samples in the initial population. + * \param genetic_algo_iters The iterations to run the genetic algorithm. + * \param max_evolve_fail_cnt The maximum number to try evolving the given trace. + * \param p_mutate The probability of mutation. + * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. + * \param database The database to use. + * \param cost_model The cost model to use. + */ + TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + int max_replay_fail_cnt, // + double init_measured_ratio, // + int genetic_algo_iters, // + int max_evolve_fail_cnt, // + double p_mutate, // + double eps_greedy, // + Database database, // CostModel cost_model); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index fbad1776d0..d15b3b6123 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -48,8 +48,8 @@ class TuneContextNode : public runtime::Object { Array sch_rules; /*! \brief The postprocessors. */ Array postprocs; - /*! \brief The mutators. */ - Array mutators; + /*! \brief The probability of using certain mutator. */ + Optional> mutator_probs; /*! \brief The name of the tuning task. */ String task_name; /*! \brief The random state. */ @@ -73,7 +73,7 @@ class TuneContextNode : public runtime::Object { v->Visit("search_strategy", &search_strategy); v->Visit("sch_rules", &sch_rules); v->Visit("postprocs", &postprocs); - v->Visit("mutators", &mutators); + v->Visit("mutator_probs", &mutator_probs); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); @@ -104,7 +104,7 @@ class TuneContext : public runtime::ObjectRef { * \param search_strategy The search strategy. 
* \param sch_rules The schedule rules. * \param postprocs The postprocessors. - * \param mutators The mutators. + * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -115,7 +115,7 @@ class TuneContext : public runtime::ObjectRef { Optional search_strategy, // Optional> sch_rules, // Optional> postprocs, // - Optional> mutators, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index fcd2326050..560ed2dbad 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -93,6 +93,7 @@ class LinearCongruentialEngine { * \param rand_state The random state given in result_type. */ void Seed(TRandState rand_state = 1) { + ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!"; rand_state %= modulus; // Make sure the seed is within the range of modulus. if (rand_state == 0) rand_state = 1; // Avoid getting all 0 given the current parameter set. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 10a0de3042..3efc419e4d 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -263,36 +263,6 @@ class TensorIntrin : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) }; -/*! - * \brief Tensor TensorIntrin for Tensorization - */ -class TensorIntrinNode : public Object { - public: - /*! \brief The function to describe the computation. */ - PrimFunc description; - /*! \brief The intrinsic function for lower-level implement. 
*/ - PrimFunc implementation; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("description", &description); - v->Visit("implementation", &implementation); - } - - static constexpr const char* _type_key = "tir.TensorIntrin"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); -}; - -class TensorIntrin : public ObjectRef { - public: - TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); - - TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); - - TVM_DLL static TensorIntrin Get(String name); - - TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) -}; - /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 441cf1cbbc..5cc36db8a9 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -294,7 +294,7 @@ def __init__( # model-related if config.nthread is None: # use physical core number - config._replace(nthread=cpu_count(logical=False)) + config = config._replace(nthread=cpu_count(logical=False)) self.config = config # serialization-related if path is not None: diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index fd746e640c..b5ca5740b2 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -147,6 +147,21 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": class Database(Object): """The abstract database interface.""" + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the database has the given workload. 
+ """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + def commit_workload(self, mod: IRModule) -> Workload: """Commit a workload to the database if missing. @@ -207,6 +222,10 @@ class PyDatabase(Database): def __init__(self): """Constructor.""" + @check_override(self.__class__, Database) + def f_has_workload(mod: IRModule) -> bool: + return self.has_workload(mod) + @check_override(self.__class__, Database) def f_commit_workload(mod: IRModule) -> Workload: return self.commit_workload(mod) @@ -225,6 +244,7 @@ def f_size() -> int: self.__init_handle_by_constructor__( _ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member + f_has_workload, f_commit_workload, f_commit_tuning_record, f_get_top_k, diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 0b5538ed40..67bb2dd85a 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, Dict from tvm._ffi import register_object -from ...tir import FloatImm from .search_strategy import SearchStrategy from ..mutator import Mutator @@ -45,16 +44,18 @@ class EvolutionarySearch(SearchStrategy): Total number of trials. population : int The initial population of traces from measured samples and randomly generated samples. + max_replay_fail_cnt : int + The maximum number to fail trace replaying. init_measured_ratio : int The ratio of measured samples in the initial population. genetic_algo_iters : int The number of iterations for genetic algorithm. + max_evolve_fail_cnt : int + The maximum number to retry mutation. p_mutate : float The probability of mutation. eps_greedy : float The ratio of greedy selected samples in the final picks. - mutator_probs: Dict[Mutator, FloatImm] - The probability contribution of all mutators. 
database : Database The database used in the search. cost_model : CostModel @@ -66,36 +67,40 @@ class EvolutionarySearch(SearchStrategy): population: int init_measured_ratio: int genetic_algo_iters: int + max_replay_fail_cnt: int + max_evolve_fail_cnt: int p_mutate: float eps_greedy: float - mutator_probs: Dict[Mutator, FloatImm] database: Database cost_model: "CostModel" def __init__( self, + *, num_trials_per_iter: int, num_trials_total: int, - population: int, - init_measured_ratio: float, - genetic_algo_iters: int, - p_mutate: float, - eps_greedy: float, - mutator_probs: Dict[Mutator, FloatImm], database: Database, cost_model: "CostModel", + population: int = 2048, + max_replay_fail_cnt: int = 64, + init_measured_ratio: float = 0.2, + genetic_algo_iters: int = 10, + max_evolve_fail_cnt: int = 10, + p_mutate: float = 0.85, + eps_greedy: float = 0.25, ): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SearchStrategyEvolutionarySearch, # pylint: disable=no-member + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member num_trials_per_iter, num_trials_total, population, + max_replay_fail_cnt, init_measured_ratio, genetic_algo_iters, + max_evolve_fail_cnt, p_mutate, eps_greedy, - mutator_probs, database, cost_model, ) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 8ba5e8727d..196b1c16b6 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,7 +16,7 @@ # under the License. """Meta Schedule tuning context.""" -from typing import List, Optional, TYPE_CHECKING +from typing import Optional, List, Dict, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object @@ -58,8 +58,8 @@ class TuneContext(Object): The schedule rules. postprocs: Optional[List[Postproc"]] = None, The postprocessors. - mutators: Optional[List[Mutator]] = None, - The mutators. 
+ mutator_probs: Optional[Dict[Mutator, float]] + Mutators and their probability mass. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -82,7 +82,7 @@ class TuneContext(Object): search_strategy: Optional["SearchStrategy"] sch_rules: List["ScheduleRule"] postprocs: List["Postproc"] - mutators: List["Mutator"] + mutator_probs: Optional[Dict["Mutator", float]] task_name: str rand_state: int num_threads: int @@ -90,42 +90,17 @@ class TuneContext(Object): def __init__( self, mod: Optional[IRModule] = None, + *, target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, sch_rules: Optional[List["ScheduleRule"]] = None, postprocs: Optional[List["Postproc"]] = None, - mutators: Optional[List["Mutator"]] = None, + mutator_probs: Optional[Dict["Mutator", float]] = None, task_name: str = "main", rand_state: int = -1, num_threads: Optional[int] = None, ): - """Constructor. - - Parameters - ---------- - mod : Optional[IRModule] = None - The workload to be optimized. - target : Optional[Target] = None - The target to be optimized for. - space_generator : Optional[SpaceGenerator] = None - The design space generator. - search_strategy : Optional[SearchStrategy] = None - The search strategy. - sch_rules : List[ScheduleRule] = [] - The schedule rules. - postprocs : List[Postproc] = [] - The postprocessors. - mutators : List[Mutator] = [] - The mutators. - task_name : str = "main" - The name of the tuning task. - rand_state : int = -1 - The random state. - Need to be in integer in [1, 2^31-1], -1 means using random number. - num_threads : Optional[int] = None - The number of threads to be used, None means using the logical cpu count. 
- """ if isinstance(mod, PrimFunc): mod = IRModule.from_expr(mod) if num_threads is None: @@ -139,7 +114,7 @@ def __init__( search_strategy, sch_rules, postprocs, - mutators, + mutator_probs, task_name, rand_state, num_threads, diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index aba59acbcb..c5871a53eb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,7 +16,6 @@ # under the License. """The TensorIR schedule class""" from typing import Dict, List, Optional, Union -from typing_extensions import Annotated from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -472,36 +471,6 @@ def get_consumers(self, block: BlockRV) -> List[BlockRV]: """ return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member - def get_producers(self, block: BlockRV) -> List[BlockRV]: - """Get the producers of a specific block - - Parameters - ---------- - block : BlockRV - The block in the query - - Returns - ------- - producers : List[BlockRV] - A list of producers of the given block - """ - return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member - - def get_consumers(self, block: BlockRV) -> List[BlockRV]: - """Get the consumers of a specific block - - Parameters - ---------- - block : BlockRV - The block in the query - - Returns - ------- - consumers : List[BlockRV] - A list of consumers of the given block - """ - return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member - ########## Schedule: Transform loops ########## def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. 
It requires: diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index e67b3d1ab9..fc7cc74de5 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -135,10 +135,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w /******** PyDatabase ********/ -Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, +Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, + PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { ObjectPtr n = make_object(); + n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; n->f_commit_tuning_record = f_commit_tuning_record; n->f_get_top_k = f_get_top_k; @@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 3efb72e2fa..2e76940fee 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -69,6 +69,10 @@ class JSONDatabaseNode : public DatabaseNode { TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); public: + bool HasWorkload(const IRModule& mod) { + return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end(); + } + Workload CommitWorkload(const 
IRModule& mod) { // Try to insert `mod` into `workloads_` decltype(this->workloads2idx_)::iterator it; diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index fd65ab9586..ab04143291 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -19,11 +19,34 @@ #include "../utils.h" +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + namespace tvm { namespace meta_schedule { /**************** Data Structure ****************/ +/*! \brief The postprocessed built of a trace */ +struct CachedTrace { + /*! \brief The type of structural hash */ + using THashCode = size_t; + /*! \brief The trace */ + tir::Trace trace{nullptr}; + /*! \brief The schedule the trace creates */ + tir::Schedule sch{nullptr}; + /*! \brief The structural hash of the schedule */ + THashCode shash; // todo(@zxybazh): deduplication + /*! \brief The normalized score, the higher the better */ + double score; + + inline bool defined() const { return trace.defined(); } + friend bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { + return lhs.score > rhs.score; + } +}; + /*! * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. * \note It maintains a min heap in terms of `CachedTrace::score`. Therefore, when @@ -31,11 +54,6 @@ namespace meta_schedule { * As time goes, the elements in the heap are going to be larger. */ class SizedHeap { - /*! \brief The comparator class, used by `std::push_heap` and `std::pop_heap` */ - struct Comparator { - bool operator()(const CachedTrace& a, const CachedTrace& b) const { return a.score > b.score; } - }; - public: /*! 
* \brief Constructor @@ -48,19 +66,19 @@ class SizedHeap { * \param item The item to be pushed */ void Push(const CachedTrace& item) { - if (!in_heap.insert(item.repr).second) { + if (!in_heap.insert(item.shash).second) { return; } int size = heap.size(); if (size < size_limit) { // Heap is not full, just push heap.emplace_back(item); - std::push_heap(heap.begin(), heap.end(), Comparator()); - } else if (Comparator()(item, heap.front())) { + std::push_heap(heap.begin(), heap.end()); + } else if (item < heap.front()) { // if the item is better than the worst one in the heap, we can safely kick it out - std::pop_heap(heap.begin(), heap.end(), Comparator()); + std::pop_heap(heap.begin(), heap.end()); heap.back() = item; - std::push_heap(heap.begin(), heap.end(), Comparator()); + std::push_heap(heap.begin(), heap.end()); } // Otherwise, the item is worse than any other element in the heap } @@ -70,9 +88,167 @@ class SizedHeap { /*! \brief The heap, the worse the topper */ std::vector heap; /*! \brief The traces that are in the heap */ - std::unordered_set in_heap; + std::unordered_set in_heap; +}; + +struct PerThreadData { + IRModule mod; + TRandState rand_state; }; +struct PerThreadDataEx { + IRModule mod; + TRandState rand_state; + std::function trace_sampler; + std::function()> mutator_sampler; + explicit PerThreadDataEx() {} + explicit PerThreadDataEx(PerThreadData* data, const std::vector& scores, double p_mutate, + const Map& mutator_probs); +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + const static int bitw = 64; + /*! + * \brief Constructor + * \param size The size of the concurrent bitmask. + */ + explicit ConcurrentBitmask(int size) : size(size) { + bitmask.assign((size + bitw - 1) / bitw, 0); + std::vector list((size + bitw - 1) / bitw); + trace_used_mutex.swap(list); + } + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per bitw(64 here) bitmasks. 
*/ + std::vector trace_used_mutex; + /*! + * \brief Query and mark the given index if not visted before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. + */ + bool query_and_mark(int x) { + if (x < 0 || x >= size) return false; + std::unique_lock lock(trace_used_mutex[x / bitw]); + if (bitmask[x / bitw] & ((uint64_t)1 << (x % bitw))) { + return false; + } else { + bitmask[x / bitw] |= (uint64_t)1 << (x % bitw); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +inline Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const CachedTrace& pick : picks) { + measure_inputs.push_back(MeasureCandidate(pick.sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +inline std::vector PredictNormalizedScore(const std::vector& cached_traces, + const TuneContext& tune_context, + const CostModel& cost_model, + const Array& args_info) { + ICHECK(cached_traces.size() > 0) + << "Candidates given for score prediction can not be empty list!"; + std::vector scores = + cost_model->Predict(tune_context, AssembleCandidates(cached_traces, args_info)); + // Normalize the score + // TODO(@junrushao1994): use softmax + temperature to replace simple normalization to [0.0, +oo) + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + +/*! 
+ * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ +inline std::function()> MakeMutatorSampler( + double p_mutate, const Map& mutator_probs, + support::LinearCongruentialEngine::TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - p_mutate); + double total_mass_mutator = 0.0; + if (p_mutate > 0) { + for (const auto& kv : mutator_probs) { + const Mutator& mutator = kv.first; + double mass = kv.second->value; + CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator + << "' is ill-formed: " << mass; + total_mass_mutator += mass; + mutators.push_back(kv.first); + masses.push_back(mass * p_mutate); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; +} + +/*! + * \brief Get the structural hash for the given schedule's IRModule. + * \param sch The given schedule. + * \return The structural hash of the schedule's IRModule. + */ +inline CachedTrace::THashCode StructuralHash(const tir::Schedule& sch) { + return tvm::StructuralHash()(sch->mod()); +} + +/*! + * \brief Constructor to create a more extensive per-thread data pack from original per-thread one. + * \param data The pointer to original per-thread data pack. + * \param scores The predicted score for the given samples. + * \param p_mutate The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. 
+ */ +PerThreadDataEx::PerThreadDataEx(PerThreadData* data, const std::vector& scores, + double p_mutate, const Map& mutator_probs) { + rand_state = ForkSeed(&data->rand_state); + mod = data->mod; + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); +} + +/**************** Evolutionary Search ****************/ + /*! * \brief A search strategy that generates measure candidates using evolutionary search. * \note The algorithm: @@ -98,145 +274,136 @@ class EvolutionarySearchNode : public SearchStrategyNode { struct State { /*! \brief The search strategy itself */ EvolutionarySearchNode* self; - /*! \brief The design spaces. */ - Array design_spaces; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; - explicit State(EvolutionarySearchNode* self, Array design_spaces) + explicit State(EvolutionarySearchNode* self, Array design_spaces) : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + /*! + * \brief Pick up best candidates from database. + * \param num The number of traces to produce. + * \return The picked best candidates. + */ + inline std::vector PickBestFromDatabase(int num); /*! * \brief Sample the initial population from previous measured results and randomly generated * traces via trace replaying. + * \param num The number of traces to produce. * \return The initial population of traces sampled. */ - inline Array SampleInitPopulation(); + inline std::vector SampleInitPopulation(int num); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param measured Measured samples from database. + * \param unmeasured Unmeasured samples from replaying traces from design space. 
+ * \return The merged results, excluding undefined samples. + */ + inline std::vector PruneAndMergeSamples( + const std::vector& measured, const std::vector& unmeasured); /*! * \brief Evolve the initial population using mutators and samplers. * \param inits The initial population of traces sampled. + * \param num The number of traces to produce. * \return The evolved traces from initial population. */ - inline Array EvolveWithCostModel(const Array& inits); + inline std::vector EvolveWithCostModel(const std::vector& inits, + int num); /*! * \brief Pick final candidates from the given initial population and bests of evolved ones. * \param inits The initial population of traces sampled. * \param bests The best candidates predicted from evolved traces. + * \param num The number of traces to produce. * \return The final picked candidates with a ratio of both. */ - inline Array PickWithEpsGreedy(const Array& inits, - const Array& bests); - /*! - * \brief Assemble measure candidates from the given candidate traces. - * \param traces The picked candidate traces. - * \return The assembled measure candidates. - */ - inline Array AssembleCandidates(const Array& picks); + inline std::vector PickWithEpsGreedy(const std::vector& inits, + const std::vector& bests, + int num); inline Optional> GenerateMeasureCandidates(); inline void NotifyRunnerResults(const Array& results); }; - /*! \brief The number of trials per iteration. */ - int num_trials_per_iter; - /*! \brief The number of total trials. */ - int num_trials_total; - /*! \brief THe population size in the evolutionary search.*/ - int population; - - /*! \brief The target for the workload. */ - Target target_{nullptr}; /*! \brief The tuning context of the evolutionary search strategy. */ TuneContext tune_context_{nullptr}; - /*! \brief The mutators to be used. */ - Array mutators_{nullptr}; - /*! \brief The module to be tuned. */ - Array mod_{nullptr}; + /*! \brief The target for the workload. 
*/ + Target target_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; - /*! \brief The number of threads to use. -1 means using logical cpu number. */ - int num_threads_ = -1; - /*! \brief The random state. -1 means using random number. */ - TRandState rand_state_ = -1; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The postprocessors. */ + Array postprocs_{nullptr}; + /*! \brief Mutators and their probability mass */ + Map mutator_probs_{nullptr}; + /*! \brief The number of threads to use. To be initialized with TuneContext. */ + int num_threads_; + /*! \brief The random state. To be initialized with TuneContext. */ + TRandState rand_state_; + /*! \brief Pre thread data including module to be tuned and random state. */ + std::vector per_thread_data_; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; + /*! \brief The token registered for the given workload in database. */ + Workload token_{nullptr}; + + /*** Configuration: global ***/ + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; /*** Configuration: the initial population ***/ + /*! \brief The population size in the evolutionary search. */ + int population; /*! \brief The ratio of measured states used in the initial population */ double init_measured_ratio; + /*! \brief The maximum number to fail trace replaying. */ + int max_replay_fail_cnt; /*** Configuration: evolution ***/ /*! \brief The number of iterations performed by generic algorithm. */ int genetic_algo_iters; + /*! \brief The maximum number to try evolving the given trace. */ + int max_evolve_fail_cnt; /*! \brief The probability to perform mutation */ double p_mutate; - /*! 
\brief Mutators and their probability mass */ - Map mutator_probs{nullptr}; - /*! \brief A Database for selecting useful candidates. */ - Database database{nullptr}; - /*! \brief A cost model helping to explore the search space */ - CostModel cost_model{nullptr}; - /*! \brief The batch of measure candidates generated for measurement. */ - Array candidates{nullptr}; /*** Configuration: pick states for measurement ***/ /*! \brief The ratio of measurements to use randomly sampled states. */ double eps_greedy; - /*! - * Helpers - * Note that the use of trace cache could be multi-threaded. - */ - mutable std::unordered_map trace_cache_; - mutable std::mutex trace_cache_mutex_; - void VisitAttrs(tvm::AttrVisitor* v) { // `tune_context_` is not visited // `target_` is not visited - // `mod_` is not visited // `args_info_` is not visited - // `num_threads_` is not visited + // `database` is not visited + // `cost_model` is not visited + // `postprocs` is not visited + // `mutator_probs_` is not visited + // `num_threads` is not visited // `rand_state_` is not visited + // `per_thread_data_` is not visited // `state_` is not visited /*** Configuration: global ***/ v->Visit("num_trials_total", &num_trials_total); v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("population", &population); /*** Configuration: the initial population ***/ + v->Visit("population", &population); v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("max_replay_fail_cnt", &max_replay_fail_cnt); /*** Configuration: evolution ***/ v->Visit("genetic_algo_iters", &genetic_algo_iters); + v->Visit("max_evolve_fail_cnt", &max_evolve_fail_cnt); v->Visit("p_mutate", &p_mutate); - v->Visit("mutator_probs", &mutator_probs); - v->Visit("cost_model", &cost_model); /*** Configuration: pick states for measurement ***/ v->Visit("eps_greedy", &eps_greedy); - /*** Helpers ***/ - // Not visited: `trace_cache_` - // Not visited: `trace_cache_mutex_` - } - - /*! 
- * \brief Add the cached trace into the trace_cache_ - * \param cached_trace The cached_trace to be added - */ - void _AddCachedTrace(const CachedTrace& cached_trace) const { - // Todo(@zxybazh): Avoid redundent traces - std::unique_lock lock(this->trace_cache_mutex_); - trace_cache_.emplace(GetRef(cached_trace.trace), cached_trace); - } - - /*! - * \brief Retrieve the cached trace given the trace - * \param trace The trace to be retrieved - * \return The cached trace - */ - CachedTrace _GetCachedTrace(const tir::Trace& trace) const { - auto iter = trace_cache_.find(trace); - ICHECK(iter != trace_cache_.end()); - return iter->second; } static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; @@ -245,27 +412,36 @@ class EvolutionarySearchNode : public SearchStrategyNode { void InitializeWithTuneContext(const TuneContext& tune_context) final { CHECK(tune_context.defined()) << "TuneContext must be defined!"; CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; - CHECK(tune_context->mutators.defined()) << "Mutators must be defined!"; + CHECK(p_mutate == 0 || tune_context->mutator_probs.defined()) + << "Mutators and their probabilities must be defined given mutation probability is not 0!"; CHECK(tune_context->target.defined()) << "Target must be defined!"; + this->tune_context_ = tune_context; this->target_ = tune_context->target.value(); - this->mutators_ = tune_context->mutators.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + if (p_mutate > 0) this->mutator_probs_ = tune_context->mutator_probs.value(); + this->postprocs_ = tune_context->postprocs; this->num_threads_ = tune_context->num_threads; - - this->mod_.reserve(this->num_threads_); + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->token_ = this->database_->CommitWorkload(tune_context->mod.value()); + this->per_thread_data_.reserve(this->num_threads_); for (int i = 0; i < this->num_threads_; 
i++) { - this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + this->per_thread_data_.push_back( + PerThreadData{DeepCopyIRModule(tune_context->mod.value()), ForkSeed(&this->rand_state_)}); } - - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); - this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } void PreTuning(const Array& design_spaces) final { ICHECK(!design_spaces.empty()); ICHECK(this->state_ == nullptr); - this->state_ = std::make_unique(this, design_spaces); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const tir::Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); } void PostTuning() final { @@ -284,312 +460,276 @@ class EvolutionarySearchNode : public SearchStrategyNode { } }; -inline Array EvolutionarySearchNode::State::SampleInitPopulation() { - self->trace_cache_.clear(); - std::vector results; - results.reserve(self->population); - // Threading RNG - std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); - // Pick measured states - int num_measured = self->population * self->init_measured_ratio; - for (TuningRecord record : - self->database->GetTopK(self->database->CommitWorkload(self->mod_[0]), num_measured)) { - results.push_back(record->trace); +inline std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + std::vector results(num, CachedTrace()); + + for (TuningRecord record : self->database_->GetTopK(self->token_, num)) { + measured_traces.push_back(record->trace); } - auto f_proc_measured = [this, &results, &per_thread_rand_state](int thread_id, - int trace_id) -> void { - TRandState& rand_state = per_thread_rand_state[trace_id]; - const tir::Trace& trace = 
results[trace_id]; + auto f_proc_measured = [this, &measured_traces, &results](int thread_id, int trace_id) -> void { + TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; + const IRModule& mod = self->per_thread_data_[thread_id].mod; + tir::Trace trace = measured_traces[trace_id]; if (Optional opt_sch = - meta_schedule::ReplayTrace(trace, self->mod_[trace_id], &rand_state)) { + meta_schedule::ApplyTrace(mod, trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - self->_AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); + results[trace_id] = CachedTrace{trace, sch, StructuralHash(sch), -1.0}; } else { LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; throw; } }; - support::parallel_for_dynamic(0, results.size(), self->num_threads_, f_proc_measured); + support::parallel_for_dynamic(0, measured_traces.size(), self->num_threads_, f_proc_measured); + return results; +} +inline std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { // Pick unmeasured states - std::atomic tot_fail_ct(0); - std::atomic success_ct(0); - auto f_proc_unmeasured = [this, &results, &per_thread_rand_state, &tot_fail_ct, &success_ct]( - int thread_id, int trace_id) -> void { - TRandState& rand_state = per_thread_rand_state[trace_id]; - for (;;) { + std::vector results; + auto f_proc_unmeasured = [this, &results, &num](int thread_id, int trace_id) -> void { + TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; + const IRModule& mod = self->per_thread_data_[thread_id].mod; + CachedTrace& result = results[trace_id]; + for (int fail_ct = 0; fail_ct <= self->max_replay_fail_cnt; fail_ct++) { int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); - tir::Trace trace = design_spaces[design_space_index]->trace().value(); - Map decisions; - try { - if (Optional opt_sch = - meta_schedule::ReplayTrace(trace, self->mod_[trace_id], &rand_state)) { - tir::Schedule sch = 
opt_sch.value(); - tir::Trace old_trace = sch->trace().value(); - tir::Trace trace(old_trace->insts, old_trace->decisions); - self->_AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); - results[trace_id] = std::move(trace); - success_ct++; - break; - } else { - tot_fail_ct++; - } - } catch (const dmlc::Error& e) { - tot_fail_ct++; - } - if (success_ct > 64) { // Todo(@junru): Why 64? Add to constructor. + tir::Trace trace = design_spaces[design_space_index]; + if (Optional opt_sch = + // replay trace, i.e., remove decisions + ApplyTrace(mod, tir::Trace(trace->insts, {}), &rand_state, self->postprocs_)) { + tir::Schedule sch = opt_sch.value(); + tir::Trace trace = sch->trace().value(); + result = CachedTrace{trace, sch, StructuralHash(sch), -1.0}; break; } } + if (!result.defined()) { + LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!"; + } }; - num_measured = results.size(); - results.resize(self->population, tir::Trace(nullptr)); - support::parallel_for_dynamic(num_measured, self->population, self->num_threads_, - f_proc_unmeasured); - std::vector pruned_results; - for (const tir::Trace& result : results) { - if (result.defined()) { - pruned_results.push_back(result); + results.resize(num, CachedTrace()); + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + return results; +} + +inline std::vector EvolutionarySearchNode::State::PruneAndMergeSamples( + const std::vector& measured, const std::vector& unmeasured) { + std::vector pruned; + pruned.reserve(measured.size() + unmeasured.size()); + for (const CachedTrace& entry : measured) { + if (entry.defined()) { + pruned.push_back(entry); } } - // LOG(INFO) << "fail count: " << tot_fail_ct; - return pruned_results; + for (const CachedTrace& entry : unmeasured) { + if (entry.defined()) { + pruned.push_back(entry); + } + } + return pruned; } -Array EvolutionarySearchNode::State::EvolveWithCostModel( - const Array& inits) { +std::vector 
EvolutionarySearchNode::State::EvolveWithCostModel( + const std::vector& inits, int num) { + std::vector per_thread_data_ex(self->num_threads_); // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated - SizedHeap heap(self->num_trials_per_iter); - // Threading RNG - std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); - std::vector> thread_trace_samplers(self->num_threads_); - std::vector()>> thread_mutator_samplers(self->num_threads_); - std::vector trace_used; - std::mutex trace_used_mutex; + SizedHeap heap(num); + // Prepare search queues std::vector sch_curr; std::vector sch_next; sch_curr.reserve(self->population); sch_next.reserve(self->population); - for (const tir::Trace& trace : inits) { - sch_curr.push_back(self->_GetCachedTrace(trace)); + for (const CachedTrace& ctrace : inits) { + sch_curr.push_back(ctrace); } // Main loop: (genetic_algo_iters + 1) times for (int iter = 0;; ++iter) { - // Predict running time with the cost model, - // and put the schedules with the predicted perf to the heap + // Predict normalized score with the cost model, std::vector scores = - PredictNormalizedScore(sch_curr, self->tune_context_, self->cost_model, self->args_info_); - for (int i = 0, n = sch_curr.size(); i < n; ++i) { - CachedTrace& entry = sch_curr[i]; - entry.score = scores[i]; - if (!self->database->GetTopK(self->database->CommitWorkload(entry.sch->mod()), 1).size()) { - heap.Push(entry); + PredictNormalizedScore(sch_curr, self->tune_context_, self->cost_model_, self->args_info_); + for (int i = 0, n = sch_curr.size(); i < n; ++i) + if (sch_curr[i].defined()) { + CachedTrace& entry = sch_curr[i]; + entry.score = scores[i]; + Workload token = self->database_->CommitWorkload(entry.sch->mod()); + if (!self->database_->GetTopK(token, 1).size()) { + heap.Push(entry); + } } - } // Discontinue once it reaches end of search if 
(iter == self->genetic_algo_iters) { break; } // Set threaded samplers, with probability from predicated normalized throughputs for (int i = 0; i < self->num_threads_; ++i) { - TRandState& rand_state = per_thread_rand_state[i]; - thread_trace_samplers[i] = MakeMultinomial(rand_state, scores); - thread_mutator_samplers[i] = - MakeMutatorSampler(self->p_mutate, self->mutator_probs, rand_state); + per_thread_data_ex[i] = + PerThreadDataEx(&self->per_thread_data_[i], scores, self->p_mutate, self->mutator_probs_); } - trace_used = std::vector(scores.size(), 0); + ConcurrentBitmask cbmask(scores.size()); // The worker function - auto f_find_candidate = [&per_thread_rand_state, &thread_trace_samplers, - &thread_mutator_samplers, &trace_used, &trace_used_mutex, &sch_curr, - &sch_next, this](int thread_id, int i) { + auto f_find_candidate = [&per_thread_data_ex, &cbmask, &sch_curr, &sch_next, this]( + int thread_id, int trace_id) { // Prepare samplers - TRandState& rand_state = per_thread_rand_state[thread_id]; - const std::function& trace_sampler = thread_trace_samplers[thread_id]; + TRandState& rand_state = per_thread_data_ex[thread_id].rand_state; + const IRModule& mod = per_thread_data_ex[thread_id].mod; + const std::function& trace_sampler = per_thread_data_ex[thread_id].trace_sampler; const std::function()>& mutator_sampler = - thread_mutator_samplers[thread_id]; + per_thread_data_ex[thread_id].mutator_sampler; // Loop until success - int max_retry_cnt = 10; - int retry_cnt = 0; - for (;;) { - int trace_idx = trace_sampler(); - const CachedTrace& cached_trace = sch_curr[trace_idx]; + for (int retry_cnt = 0; retry_cnt <= self->max_evolve_fail_cnt; retry_cnt++) { + int sampled_trace_id = trace_sampler(); + const CachedTrace& ctrace = sch_curr[sampled_trace_id]; + // skip undefined cached trace + if (!ctrace.defined()) continue; if (Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); - if (Optional opt_new_trace = - 
mutator->Apply(GetRef(cached_trace.trace))) { + if (Optional opt_new_trace = mutator->Apply(ctrace.trace)) { tir::Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = - ReplayTrace(new_trace, self->mod_[i], &rand_state)) { + ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - CachedTrace new_cached_trace{new_trace.get(), sch, Repr(sch), -1.0}; - self->_AddCachedTrace(new_cached_trace); - sch_next[i] = new_cached_trace; + CachedTrace new_ctrace{new_trace, sch, StructuralHash(sch), -1.0}; + sch_next[trace_id] = new_ctrace; break; } } } else { // Decision: do not mutate - std::unique_lock lock(trace_used_mutex); - if (!trace_used[trace_idx]) { - trace_used[trace_idx] = 1; - sch_next[i] = cached_trace; + if (cbmask.query_and_mark(sampled_trace_id)) { + sch_next[trace_id] = ctrace; break; } } - retry_cnt++; - if (retry_cnt >= max_retry_cnt) { - sch_next[i] = cached_trace; - break; - } } + // if retry count exceeds the limit, the result remain undefined and will not be used }; sch_next.clear(); - sch_next.resize(self->population); - support::parallel_for_dynamic(0, self->population, 1, f_find_candidate); + sch_next.resize(self->population, CachedTrace()); + support::parallel_for_dynamic(0, self->population, self->num_threads_, f_find_candidate); sch_curr.clear(); sch_curr.swap(sch_next); } // Return the best states from the heap, sorting from higher score to lower ones - std::sort(heap.heap.begin(), heap.heap.end(), CachedTrace::Compare); - Array results; - results.reserve(self->num_trials_per_iter); + std::sort(heap.heap.begin(), heap.heap.end()); + std::vector results; + results.reserve(num); for (const CachedTrace& item : heap.heap) { - results.push_back(GetRef(item.trace)); + results.push_back(item); } - /* Logging - constexpr int kNumScoresPerLine = 16; - std::ostringstream os; - int n = heap.heap.size(); - for (int st = 0; st < n; st += kNumScoresPerLine) { - os << std::endl; - int ed = std::min(st + 
kNumScoresPerLine, n); - os << "[" << (st + 1) << " : " << ed << "]:\t"; - for (int i = st; i < ed; ++i) { - if (i != st) { - os << " "; - } - os << std::fixed << std::setprecision(4) << heap.heap[i].score; + + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; } + os << std::fixed << std::setprecision(4) << heap.heap[i].score; } - LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); - */ + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); return results; } -Array EvolutionarySearchNode::State::PickWithEpsGreedy(const Array& inits, - const Array& bests) { - int num_rands = self->num_trials_per_iter * self->eps_greedy; - int num_bests = self->num_trials_per_iter - num_rands; +std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( + const std::vector& unmeasured, const std::vector& bests, int num) { + int num_rands = num * self->eps_greedy; + int num_bests = num - num_rands; std::vector rands = - tir::SampleWithoutReplacement(&self->rand_state_, inits.size(), inits.size()); - Array results; - results.reserve(self->num_trials_per_iter); - for (int i = 0, i_bests = 0, i_rands = 0; i < self->num_trials_per_iter; ++i) { + tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + std::vector results; + results.reserve(num); + for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { bool has_best = i_bests < static_cast(bests.size()); bool has_rand = i_rands < static_cast(rands.size()); // Pick a schedule - Optional trace{NullOpt}; + CachedTrace ctrace; // If needs `bests`, then prefer `bests` if (i < num_bests) { if (has_best) { - trace = bests[i_bests++]; + ctrace = bests[i_bests++]; } else if (has_rand) { - trace = 
inits[rands[i_rands++]]; + ctrace = unmeasured[rands[i_rands++]]; } else { break; } } else { // Else prefer `rands` if (has_rand) { - trace = inits[rands[i_rands++]]; + ctrace = unmeasured[rands[i_rands++]]; } else if (has_best) { - trace = bests[i_bests++]; + ctrace = bests[i_bests++]; } else { break; } } - results.push_back(trace.value()); + results.push_back(ctrace); } return results; } -inline Array EvolutionarySearchNode::State::AssembleCandidates( - const Array& picks) { - Array measure_inputs; - measure_inputs.reserve(picks.size()); - for (const tir::Trace& pick : picks) { - CachedTrace trace = self->_GetCachedTrace(pick); - measure_inputs.push_back(MeasureCandidate(trace.sch, self->args_info_)); - } - return measure_inputs; -} - inline Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { if (st >= self->num_trials_total) { - self->candidates = Array(nullptr); return NullOpt; } + int sample_num = self->num_trials_per_iter; if (ed > self->num_trials_total) { - self->num_trials_per_iter += self->num_trials_total - ed; + sample_num += self->num_trials_total - ed; ed = self->num_trials_total; } ICHECK_LT(st, ed); - // new parts - Array inits = SampleInitPopulation(); - Array bests = EvolveWithCostModel(inits); - Array picks = PickWithEpsGreedy(inits, bests); - self->candidates = AssembleCandidates(picks); - return self->candidates; + std::vector measured = + PickBestFromDatabase(self->population * self->init_measured_ratio); + std::vector unmeasured = SampleInitPopulation(self->population - measured.size()); + std::vector inits = PruneAndMergeSamples(measured, unmeasured); + std::vector bests = EvolveWithCostModel(inits, sample_num); + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + return AssembleCandidates(picks, self->args_info_); } inline void EvolutionarySearchNode::State::NotifyRunnerResults(const Array& results) { - // We need to assume the candidates' order are not changed in runner. 
- ICHECK(self->candidates.defined() && self->candidates.size() == results.size()); st += results.size(); ed += results.size(); - int i = 0; - for (const RunnerResult& result : results) { - // Todo: Update to database measure callback - if (result->error_msg.defined() || !result->run_secs.defined()) continue; - self->database->CommitTuningRecord(TuningRecord( - /*trace=*/self->candidates[i]->sch->trace().value(), // - /*run_secs=*/result->run_secs.value(), // - /*workload=*/self->database->CommitWorkload(self->mod_[0]), // - /*target=*/self->target_, // - /*args_info=*/self->candidates[i]->args_info)); - // Todo: Update to cost model measure callback - self->cost_model->Update(self->tune_context_, self->candidates, results); - i++; - } + // Measure Callbacks done in TaskScheduler } -SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // - int num_trials_total, // - int population, // - double init_measured_ratio, // - int genetic_algo_iters, // - double p_mutate, // - double eps_greedy, // - Map mutator_probs, // - Database database, // +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population, // + int max_replay_fail_cnt, // + double init_measured_ratio, // + int genetic_algo_iters, // + int max_evolve_fail_cnt, // + double p_mutate, // + double eps_greedy, // + Database database, // CostModel cost_model) { ObjectPtr n = make_object(); n->num_trials_per_iter = num_trials_per_iter; n->num_trials_total = num_trials_total; n->population = population; + n->max_replay_fail_cnt = max_replay_fail_cnt; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); n->init_measured_ratio = init_measured_ratio; n->genetic_algo_iters = genetic_algo_iters; + n->max_evolve_fail_cnt = max_evolve_fail_cnt; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(p_mutate, "Mutation probability"); n->p_mutate = p_mutate; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick 
probability"); n->eps_greedy = eps_greedy; - n->mutator_probs = mutator_probs; - n->database = database; - n->cost_model = cost_model; + n->database_ = database; + n->cost_model_ = cost_model; return SearchStrategy(n); } diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 4261b6f889..784ef6306e 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -30,7 +30,7 @@ TuneContext::TuneContext(Optional mod, Optional search_strategy, // Optional> sch_rules, // Optional> postprocs, // - Optional> mutators, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -41,7 +41,7 @@ TuneContext::TuneContext(Optional mod, n->search_strategy = search_strategy; n->sch_rules = sch_rules.value_or({}); n->postprocs = postprocs.value_or({}); - n->mutators = mutators.value_or({}); + n->mutator_probs = mutator_probs; n->task_name = task_name.value_or("main"); if (rand_state == -1) { rand_state = std::random_device()(); @@ -67,8 +67,10 @@ void TuneContextNode::Initialize() { for (const Postproc& postproc : postprocs) { postproc->InitializeWithTuneContext(GetRef(this)); } - for (const Mutator& mutator : mutators) { - mutator->InitializeWithTuneContext(GetRef(this)); + if (mutator_probs.defined()) { + for (const auto& kv : mutator_probs.value()) { + kv.first->InitializeWithTuneContext(GetRef(this)); + } } } @@ -81,12 +83,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional search_strategy, // Array sch_rules, // Array postprocs, // - Array mutators, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, - mutators, task_name, rand_state, num_threads); + mutator_probs, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // 
namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 12982730a4..a6af196303 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -37,7 +37,6 @@ #include #include -#include #include #include @@ -251,8 +250,10 @@ inline int GetTargetNumCores(const Target& target) { ICHECK(f_cpu_count) << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; num_cores = (*f_cpu_count)(false); - LOG(FATAL) << "Target does not have attribute \"num_cores\", pyhsical core number must be " - "defined! Example: Local target ...."; + LOG(FATAL) + << "Target does not have attribute \"num-cores\", pyhsical core number must be " + "defined! For example, on the local machine, the target must be \"llvm -num-cores " + << num_cores << "\""; } return num_cores; } @@ -283,127 +284,6 @@ inline Optional ApplyTrace(const IRModule& mod, const tir::Trace& return sch; } -/*! - * \brief Get the string representation for the given schedule's IRModule. - * \param sch The given schedule. - * \return The string representation created. - */ -inline String Repr(const tir::Schedule& sch) { return tir::AsTVMScript(sch->mod()); } - -/*! - * \brief Create a sampling function that does multinomial sampling. - * \param rand_state The random state. - * \param weights The weights for multinomial sampling. - * \return The multinomial sampling function. 
- */ -inline std::function MakeMultinomial( - support::LinearCongruentialEngine::TRandState& rand_state, const std::vector& weights) { - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - for (double w : weights) { - sums.push_back(sum += w); - } - std::uniform_real_distribution dist(0.0, sum); - auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { - support::LinearCongruentialEngine rand_(&rand_state); - double p = dist(rand_); - int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? (n - 1) : idx; - }; - return sampler; -} - -/*! - * \brief Create a sampler function that picks mutators according to the mass function - * \param rand_state The random state for sampling - * \return The sampler created - */ -inline std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, - support::LinearCongruentialEngine::TRandState& rand_state) { - CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // - << "ValueError: Probability should be within [0, 1], " - << "but get `p_mutate = " << p_mutate << '\''; - std::vector> mutators; - std::vector masses; - mutators.push_back(NullOpt); - masses.push_back(1.0 - p_mutate); - double total_mass_mutator = 0.0; - for (const auto& kv : mutator_probs) { - const Mutator& mutator = kv.first; - double mass = kv.second->value; - CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator - << "' is ill-formed: " << mass; - total_mass_mutator += mass; - mutators.push_back(kv.first); - masses.push_back(mass * p_mutate); - } - // Normalize the sum to 1.0 - if (total_mass_mutator == 0.0) { - masses[0] = 1.0; - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] = 0.0; - } - } else if (total_mass_mutator != 1.0) { - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] /= total_mass_mutator; - } - } - auto idx_sampler = 
MakeMultinomial(rand_state, masses); - return [idx_sampler = std::move(idx_sampler), - mutators = std::move(mutators)]() -> Optional { - int i = idx_sampler(); - return mutators[i]; - }; -} - -/*! \brief The postprocessed built of a trace */ -struct CachedTrace { - /*! \brief The trace */ - const tir::TraceNode* trace; - /*! \brief The schedule the trace creates */ - tir::Schedule sch; - /*! \brief The string representation of the schedule */ - String repr; - // Todo: Challenges in deduplication: remove unit loop / simplify pass - /*! \brief The normalized score, the higher the better */ - double score; - - static bool Compare(const CachedTrace& lhs, const CachedTrace& rhs) { - return lhs.score > rhs.score; - } -}; - -/*! - * \brief Predict the normalized score of each candidate. - * \param candidates The candidates for prediction - * \param task The search task - * \param space The search space - * \return The normalized score in the prediction - */ -inline std::vector PredictNormalizedScore(const std::vector& cached_traces, - const TuneContext& tune_context, - const CostModel& cost_model, - Array args_info) { - Array measure_inputs; - measure_inputs.reserve(cached_traces.size()); - for (const CachedTrace& cached_trace : cached_traces) { - measure_inputs.push_back(MeasureCandidate(cached_trace.sch, args_info)); - } - - std::vector scores = cost_model->Predict(tune_context, measure_inputs); - // Normalize the score - // TODO(@junrushao1994): use softmax + temperature to replace simple normalization to [0.0, +oo) - for (double& score : scores) { - score = std::max(0.0, score); - } - return scores; -} - } // namespace meta_schedule } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index f1dfed1644..2be6f5035a 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -227,11 +227,10 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") - 
.add_attr_option("num_cores") + .add_attr_option("num-cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags .add_attr_option("fast-math") // implies all the below .add_attr_option("fast-math-nnan") diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e571896cd6..5ed88da65d 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1672,17 +1672,5 @@ bool HasIfThenElse(const Stmt& stmt) { return has_branch; } -bool CheckOneLine(const Stmt& s) { - bool legal = true, meet_block = false; - PostOrderVisit(s, [&legal, &meet_block](const ObjectRef& obj) { - if (obj->IsInstance() && !meet_block) { - legal = false; - } else if (obj->IsInstance()) { - meet_block = true; - } - }); - return legal; -} - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 26f6d7d594..1cb2a6df8b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -30,7 +30,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0faad06492..a4d9d1f7b0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,11 +22,21 @@ #include #include +#include #include namespace tvm { namespace tir { +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. 
+ * \return The multinomial sampling function. + */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); + /******** Schedule: Sampling ********/ /*! * \brief Sample a random integer from a given range. diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6061ca527e..2e56f12dba 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -86,6 +86,7 @@ struct PrimeTable { pow_tab.emplace_back(std::move(tab)); } } + /*! * \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized @@ -124,6 +125,28 @@ struct PrimeTable { } }; +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [rand_state = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::move(dist), sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rand_state); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? 
(n - 1) : idx; + }; + return sampler; +} + int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive) { CHECK(min_inclusive < max_exclusive) diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 7b5c68d7d2..a5d044ca31 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index be041d543e..e57799f604 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -79,7 +79,6 @@ def _create_context(mod, target, rule): return ctx -# @pytest.mark.skip(reason="failing in staging branch @bohan") def test_parallel_vectorize_unroll(): expected = [ [ diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index a21481d5dd..4e51d497d0 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -80,10 +80,9 @@ def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> def _schedule_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming - i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) - j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) - 
k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4)) + j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4)) + k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2)) sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) @@ -136,6 +135,12 @@ def __init__(self): self.records = [] self.workload_reg = [] + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + def commit_tuning_record(self, record: TuningRecord) -> None: self.records.append(record) @@ -216,17 +221,18 @@ def predict( population=5, init_measured_ratio=0.1, genetic_algo_iters=3, + max_evolve_fail_cnt=10, p_mutate=0.5, eps_greedy=0.9, - mutator_probs={mutator: 1.0}, database=database, cost_model=cost_model, ) tune_context = TuneContext( mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul), - mutators=[mutator], + mutator_probs={mutator: 1.0}, target=tvm.target.Target("llvm"), + num_threads=1, # beacuse we are using a mutator from the python side ) tune_context.space_generator.initialize_with_tune_context(tune_context) spaces = tune_context.space_generator.generate_design_space(tune_context.mod) @@ -250,10 +256,10 @@ def predict( candidates = strategy.generate_measure_candidates() strategy.post_tuning() print(num_trials_each_iter) - correct_count = 6 # For each iteration except the last one - assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + [ - num_trials_total % correct_count - ] + correct_count = 10 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( + [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py 
b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 438961f0b8..d3c4dbca82 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -138,6 +138,12 @@ def __init__(self): self.records = [] self.workload_reg = [] + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + def commit_tuning_record(self, record: TuningRecord) -> None: self.records.append(record) From c184547806908d9d798b0e7acce0f5d4252a6b03 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 5 Dec 2021 13:10:36 -0800 Subject: [PATCH 08/14] Fix TuneContext. --- include/tvm/meta_schedule/tune_context.h | 2 +- src/meta_schedule/tune_context.cc | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index d15b3b6123..eef7ae2b8d 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -49,7 +49,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The postprocessors. */ Array postprocs; /*! \brief The probability of using certain mutator. */ - Optional> mutator_probs; + Map mutator_probs; /*! \brief The name of the tuning task. */ String task_name; /*! \brief The random state. 
*/ diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 784ef6306e..2df6bee862 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -41,7 +41,7 @@ TuneContext::TuneContext(Optional mod, n->search_strategy = search_strategy; n->sch_rules = sch_rules.value_or({}); n->postprocs = postprocs.value_or({}); - n->mutator_probs = mutator_probs; + n->mutator_probs = mutator_probs.value_or({}); n->task_name = task_name.value_or("main"); if (rand_state == -1) { rand_state = std::random_device()(); @@ -68,7 +68,7 @@ void TuneContextNode::Initialize() { postproc->InitializeWithTuneContext(GetRef(this)); } if (mutator_probs.defined()) { - for (const auto& kv : mutator_probs.value()) { + for (const auto& kv : mutator_probs) { kv.first->InitializeWithTuneContext(GetRef(this)); } } @@ -81,8 +81,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional target, // Optional space_generator, // Optional search_strategy, // - Array sch_rules, // - Array postprocs, // + Optional> sch_rules, // + Optional> postprocs, // Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // From eb8a4aeca9044141d65c02c971c6a89e9d2fc337 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 5 Dec 2021 14:40:16 -0800 Subject: [PATCH 09/14] Fix haash & stuff. --- .../search_strategy/evolutionary_search.cc | 165 ++++++++++-------- 1 file changed, 94 insertions(+), 71 deletions(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index ab04143291..edca49cf34 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -26,6 +26,15 @@ namespace tvm { namespace meta_schedule { +/*! 
+ * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ +inline std::function()> MakeMutatorSampler( + double p_mutate, const Map& mutator_probs, + support::LinearCongruentialEngine::TRandState* rand_state); + /**************** Data Structure ****************/ /*! \brief The postprocessed built of a trace */ @@ -41,8 +50,21 @@ struct CachedTrace { /*! \brief The normalized score, the higher the better */ double score; + /*! + * \brief Get the structural hash for the given schedule's IRModule. + * \param sch The given schedule. + * \return The structural hash of the schedule's IRModule. + */ + inline CachedTrace::THashCode StructuralHash(const tir::Schedule& sch) { + return tvm::StructuralHash()(sch->mod()); + } + + CachedTrace() = default; + explicit CachedTrace(const tir::Trace& trace, const tir::Schedule& sch, double score) + : trace(trace), sch(sch), shash(StructuralHash(sch)), score(score) {} + inline bool defined() const { return trace.defined(); } - friend bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { + friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { return lhs.score > rhs.score; } }; @@ -55,6 +77,20 @@ struct CachedTrace { */ class SizedHeap { public: + struct IRModuleSHash { + IRModule mod; + CachedTrace::THashCode shash; + }; + + struct IRModuleSHashHash { + size_t operator()(const IRModuleSHash& hash) const { return hash.shash; } + }; + + struct IRModuleSHashEqual { + bool operator()(const IRModuleSHash& lhs, const IRModuleSHash& rhs) const { + return StructuralEqual()(lhs.mod, rhs.mod); + } + }; /*! 
* \brief Constructor * \param size_limit The up-limit of the heap size @@ -66,7 +102,7 @@ class SizedHeap { * \param item The item to be pushed */ void Push(const CachedTrace& item) { - if (!in_heap.insert(item.shash).second) { + if (!in_heap.insert(IRModuleSHash{item.sch->mod(), item.shash}).second) { return; } int size = heap.size(); @@ -88,7 +124,7 @@ class SizedHeap { /*! \brief The heap, the worse the topper */ std::vector heap; /*! \brief The traces that are in the heap */ - std::unordered_set in_heap; + std::unordered_set in_heap; }; struct PerThreadData { @@ -101,41 +137,54 @@ struct PerThreadDataEx { TRandState rand_state; std::function trace_sampler; std::function()> mutator_sampler; - explicit PerThreadDataEx() {} - explicit PerThreadDataEx(PerThreadData* data, const std::vector& scores, double p_mutate, - const Map& mutator_probs); + + /*! \brief Default constructor. */ + PerThreadDataEx() = default; + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param p_mutate The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double p_mutate, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); + } }; struct ConcurrentBitmask { /*! The bit width. */ - const static int bitw = 64; + static constexpr const int kBitWidth = 64; /*! * \brief Constructor * \param size The size of the concurrent bitmask. */ explicit ConcurrentBitmask(int size) : size(size) { - bitmask.assign((size + bitw - 1) / bitw, 0); - std::vector list((size + bitw - 1) / bitw); - trace_used_mutex.swap(list); + bitmask.assign((size + kBitWidth - 1) / kBitWidth, 0); + std::vector list((size + kBitWidth - 1) / kBitWidth); + mutexes.swap(list); } /*! \brief The size of the concurrent bitmask. 
*/ int size; /*! \brief The bitmasks. */ std::vector bitmask; - /*! \brief The mutexes, one per bitw(64 here) bitmasks. */ - std::vector trace_used_mutex; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; /*! * \brief Query and mark the given index if not visted before. * \param x The index to concurrently check if used. If not, mark as used. * \return Whether the index has been used before. */ - bool query_and_mark(int x) { + bool QueryAndMark(int x) { if (x < 0 || x >= size) return false; - std::unique_lock lock(trace_used_mutex[x / bitw]); - if (bitmask[x / bitw] & ((uint64_t)1 << (x % bitw))) { + std::unique_lock lock(mutexes[x / kBitWidth]); + constexpr uint64_t one = 1; + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { return false; } else { - bitmask[x / bitw] |= (uint64_t)1 << (x % bitw); + bitmask[x / kBitWidth] |= one << (x % kBitWidth); return true; } } @@ -181,11 +230,6 @@ inline std::vector PredictNormalizedScore(const std::vector return scores; } -/*! - * \brief Create a sampler function that picks mutators according to the mass function - * \param rand_state The random state for sampling - * \return The sampler created - */ inline std::function()> MakeMutatorSampler( double p_mutate, const Map& mutator_probs, support::LinearCongruentialEngine::TRandState* rand_state) { @@ -223,30 +267,6 @@ inline std::function()> MakeMutatorSampler( }; } -/*! - * \brief Get the structural hash for the given schedule's IRModule. - * \param sch The given schedule. - * \return The structural hash of the schedule's IRModule. - */ -inline CachedTrace::THashCode StructuralHash(const tir::Schedule& sch) { - return tvm::StructuralHash()(sch->mod()); -} - -/*! - * \brief Constructor to create a more extensive per-thread data pack from original per-thread one. - * \param data The pointer to original per-thread data pack. - * \param scores The predicted score for the given samples. - * \param p_mutate The probability of mutation. 
- * \param mutator_probs The probability of each mutator as a dict. - */ -PerThreadDataEx::PerThreadDataEx(PerThreadData* data, const std::vector& scores, - double p_mutate, const Map& mutator_probs) { - rand_state = ForkSeed(&data->rand_state); - mod = data->mod; - trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); - mutator_sampler = MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); -} - /**************** Evolutionary Search ****************/ /*! @@ -419,7 +439,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->tune_context_ = tune_context; this->target_ = tune_context->target.value(); this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); - if (p_mutate > 0) this->mutator_probs_ = tune_context->mutator_probs.value(); + this->mutator_probs_ = tune_context->mutator_probs; this->postprocs_ = tune_context->postprocs; this->num_threads_ = tune_context->num_threads; this->rand_state_ = ForkSeed(&tune_context->rand_state); @@ -476,7 +496,7 @@ inline std::vector EvolutionarySearchNode::State::PickBestFromDatab if (Optional opt_sch = meta_schedule::ApplyTrace(mod, trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - results[trace_id] = CachedTrace{trace, sch, StructuralHash(sch), -1.0}; + results[trace_id] = CachedTrace(trace, sch, -1.0); } else { LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; throw; @@ -493,7 +513,7 @@ inline std::vector EvolutionarySearchNode::State::SampleInitPopulat TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; const IRModule& mod = self->per_thread_data_[thread_id].mod; CachedTrace& result = results[trace_id]; - for (int fail_ct = 0; fail_ct <= self->max_replay_fail_cnt; fail_ct++) { + for (int fail_ct = 0; fail_ct < self->max_replay_fail_cnt; fail_ct++) { int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; if (Optional 
opt_sch = @@ -501,7 +521,7 @@ inline std::vector EvolutionarySearchNode::State::SampleInitPopulat ApplyTrace(mod, tir::Trace(trace->insts, {}), &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); tir::Trace trace = sch->trace().value(); - result = CachedTrace{trace, sch, StructuralHash(sch), -1.0}; + result = CachedTrace(trace, sch, -1.0); break; } } @@ -534,6 +554,10 @@ inline std::vector EvolutionarySearchNode::State::PruneAndMergeSamp std::vector EvolutionarySearchNode::State::EvolveWithCostModel( const std::vector& inits, int num) { std::vector per_thread_data_ex(self->num_threads_); + for (int i = 0; i < self->num_threads_; i++) { + per_thread_data_ex[i].mod = self->per_thread_data_[i].mod; + per_thread_data_ex[i].rand_state = ForkSeed(&self->per_thread_data_[i].rand_state); + } // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(num); @@ -551,23 +575,20 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( // Predict normalized score with the cost model, std::vector scores = PredictNormalizedScore(sch_curr, self->tune_context_, self->cost_model_, self->args_info_); - for (int i = 0, n = sch_curr.size(); i < n; ++i) - if (sch_curr[i].defined()) { - CachedTrace& entry = sch_curr[i]; - entry.score = scores[i]; - Workload token = self->database_->CommitWorkload(entry.sch->mod()); - if (!self->database_->GetTopK(token, 1).size()) { - heap.Push(entry); - } + for (int i = 0, n = sch_curr.size(); i < n; ++i) { + CachedTrace& entry = sch_curr[i]; + entry.score = scores[i]; + if (!self->database_->HasWorkload(entry.sch->mod())) { + heap.Push(entry); } + } // Discontinue once it reaches end of search if (iter == self->genetic_algo_iters) { break; } // Set threaded samplers, with probability from predicated normalized throughputs for (int i = 0; i < self->num_threads_; ++i) { - per_thread_data_ex[i] = - 
PerThreadDataEx(&self->per_thread_data_[i], scores, self->p_mutate, self->mutator_probs_); + per_thread_data_ex[i].Set(scores, self->p_mutate, self->mutator_probs_); } ConcurrentBitmask cbmask(scores.size()); // The worker function @@ -579,12 +600,11 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( const std::function& trace_sampler = per_thread_data_ex[thread_id].trace_sampler; const std::function()>& mutator_sampler = per_thread_data_ex[thread_id].mutator_sampler; + CachedTrace& result = sch_next[trace_id]; // Loop until success - for (int retry_cnt = 0; retry_cnt <= self->max_evolve_fail_cnt; retry_cnt++) { + for (int retry_cnt = 0; retry_cnt < self->max_evolve_fail_cnt; retry_cnt++) { int sampled_trace_id = trace_sampler(); const CachedTrace& ctrace = sch_curr[sampled_trace_id]; - // skip undefined cached trace - if (!ctrace.defined()) continue; if (Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); @@ -593,20 +613,23 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( if (Optional opt_sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - CachedTrace new_ctrace{new_trace, sch, StructuralHash(sch), -1.0}; - sch_next[trace_id] = new_ctrace; + // note that sch's trace is different from new_trace + // beacuase it contains post-processing infomation + CachedTrace new_ctrace(sch->trace().value(), sch, -1.0); + result = new_ctrace; break; } } - } else { + } else if (cbmask.QueryAndMark(sampled_trace_id)) { // Decision: do not mutate - if (cbmask.query_and_mark(sampled_trace_id)) { - sch_next[trace_id] = ctrace; - break; - } + result = ctrace; + break; + } + // if retry count exceeds the limit, the result should be just ctrace + if (retry_cnt + 1 == self->max_evolve_fail_cnt) { + sch_next[trace_id] = ctrace; } } - // if retry count exceeds the limit, the result remain undefined and will not be used }; sch_next.clear(); 
sch_next.resize(self->population, CachedTrace()); @@ -684,7 +707,7 @@ EvolutionarySearchNode::State::GenerateMeasureCandidates() { } int sample_num = self->num_trials_per_iter; if (ed > self->num_trials_total) { - sample_num += self->num_trials_total - ed; + sample_num = self->num_trials_total - st; ed = self->num_trials_total; } ICHECK_LT(st, ed); From eb86febb2a5951872bf50a5e0fad795784eebe28 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 6 Dec 2021 11:18:23 -0800 Subject: [PATCH 10/14] Modifyy shash. --- .../search_strategy/evolutionary_search.cc | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index edca49cf34..3c05aa1dcd 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -45,23 +45,12 @@ struct CachedTrace { tir::Trace trace{nullptr}; /*! \brief The schedule the trace creates */ tir::Schedule sch{nullptr}; - /*! \brief The structural hash of the schedule */ - THashCode shash; // todo(@zxybazh): deduplication /*! \brief The normalized score, the higher the better */ double score; - /*! - * \brief Get the structural hash for the given schedule's IRModule. - * \param sch The given schedule. - * \return The structural hash of the schedule's IRModule. 
- */ - inline CachedTrace::THashCode StructuralHash(const tir::Schedule& sch) { - return tvm::StructuralHash()(sch->mod()); - } - CachedTrace() = default; explicit CachedTrace(const tir::Trace& trace, const tir::Schedule& sch, double score) - : trace(trace), sch(sch), shash(StructuralHash(sch)), score(score) {} + : trace(trace), sch(sch), score(score) {} inline bool defined() const { return trace.defined(); } friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { @@ -88,7 +77,7 @@ class SizedHeap { struct IRModuleSHashEqual { bool operator()(const IRModuleSHash& lhs, const IRModuleSHash& rhs) const { - return StructuralEqual()(lhs.mod, rhs.mod); + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); } }; /*! @@ -102,7 +91,7 @@ class SizedHeap { * \param item The item to be pushed */ void Push(const CachedTrace& item) { - if (!in_heap.insert(IRModuleSHash{item.sch->mod(), item.shash}).second) { + if (!in_heap.insert(IRModuleSHash{item.sch->mod(), StructuralHash()(item.sch->mod())}).second) { return; } int size = heap.size(); @@ -285,6 +274,8 @@ inline std::function()> MakeMutatorSampler( * chosen = pick top `k = num_measures_per_iter * (1 - eps_greedy)` from `best` * pick `k = num_measures_per_iter * eps_greedy ` from `init` * do the measurement on `chosen` & update the cost model + * + * Todo: (@zxybazh): Early stopping for small search space, including deduplication. */ class EvolutionarySearchNode : public SearchStrategyNode { public: From ece4c4f36ef6d6513e7dca753be3d5a08c0a600b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 6 Dec 2021 11:46:58 -0800 Subject: [PATCH 11/14] Remove trace field. 
--- .../search_strategy/evolutionary_search.cc | 57 ++++++++++++------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 3c05aa1dcd..f44c58e498 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -37,22 +37,40 @@ inline std::function()> MakeMutatorSampler( /**************** Data Structure ****************/ -/*! \brief The postprocessed built of a trace */ +/*! + * \brief The struct to store schedule, trace and its score. + * \note The trace is available by visiting the schedule's trace method. + */ struct CachedTrace { - /*! \brief The type of structural hash */ - using THashCode = size_t; - /*! \brief The trace */ - tir::Trace trace{nullptr}; - /*! \brief The schedule the trace creates */ + /*! \brief The schedule the trace creates. */ tir::Schedule sch{nullptr}; - /*! \brief The normalized score, the higher the better */ + /*! \brief The normalized score, the higher the better. */ double score; + /*! \brief Default constructor. */ CachedTrace() = default; - explicit CachedTrace(const tir::Trace& trace, const tir::Schedule& sch, double score) - : trace(trace), sch(sch), score(score) {} - - inline bool defined() const { return trace.defined(); } + /*! + * \brief Constructor from Schedule and score. + * \param sch The given Schedule, which can be used to obtain the trace. + * \param score The predicted normalized score, -1.0 if score is not assigned yet. + */ + explicit CachedTrace(const tir::Schedule& sch, double score) : sch(sch), score(score) {} + /*! + * \brief Check if the cached trace is defined. + * \return Whether the cached trace is defined. + */ + inline bool Defined() const { return sch.defined(); } + /*! + * \brief Get trace from a cached trace. + * \return The trace. 
+ */ + inline tir::Trace GetTrace() const { + Optional trace; + ICHECK(sch.defined() && (trace = sch->trace())) + << "Schedule or trace is not defined when getting trace!"; + return trace.value(); + } + /*! \brief Reload the operator < for CachedTrace. */ friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { return lhs.score > rhs.score; } @@ -68,7 +86,7 @@ class SizedHeap { public: struct IRModuleSHash { IRModule mod; - CachedTrace::THashCode shash; + size_t shash; }; struct IRModuleSHashHash { @@ -487,7 +505,7 @@ inline std::vector EvolutionarySearchNode::State::PickBestFromDatab if (Optional opt_sch = meta_schedule::ApplyTrace(mod, trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - results[trace_id] = CachedTrace(trace, sch, -1.0); + results[trace_id] = CachedTrace(sch, -1.0); } else { LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; throw; @@ -511,12 +529,11 @@ inline std::vector EvolutionarySearchNode::State::SampleInitPopulat // replay trace, i.e., remove decisions ApplyTrace(mod, tir::Trace(trace->insts, {}), &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); - tir::Trace trace = sch->trace().value(); - result = CachedTrace(trace, sch, -1.0); + result = CachedTrace(sch, -1.0); break; } } - if (!result.defined()) { + if (!result.Defined()) { LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!"; } }; @@ -530,12 +547,12 @@ inline std::vector EvolutionarySearchNode::State::PruneAndMergeSamp std::vector pruned; pruned.reserve(measured.size() + unmeasured.size()); for (const CachedTrace& entry : measured) { - if (entry.defined()) { + if (entry.Defined()) { pruned.push_back(entry); } } for (const CachedTrace& entry : unmeasured) { - if (entry.defined()) { + if (entry.Defined()) { pruned.push_back(entry); } } @@ -599,14 +616,14 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( if (Optional opt_mutator = mutator_sampler()) { // Decision: 
mutate Mutator mutator = opt_mutator.value(); - if (Optional opt_new_trace = mutator->Apply(ctrace.trace)) { + if (Optional opt_new_trace = mutator->Apply(ctrace.GetTrace())) { tir::Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { tir::Schedule sch = opt_sch.value(); // note that sch's trace is different from new_trace // beacuase it contains post-processing infomation - CachedTrace new_ctrace(sch->trace().value(), sch, -1.0); + CachedTrace new_ctrace(sch, -1.0); result = new_ctrace; break; } From 3351e733cb0d10f869fee7f4704f4472757f06ca Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 6 Dec 2021 12:22:57 -0800 Subject: [PATCH 12/14] Minor fix. --- .../search_strategy/evolutionary_search.cc | 120 ++++++++---------- 1 file changed, 50 insertions(+), 70 deletions(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index f44c58e498..8bdfc379ac 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -26,15 +26,6 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Create a sampler function that picks mutators according to the mass function - * \param rand_state The random state for sampling - * \return The sampler created - */ -inline std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, - support::LinearCongruentialEngine::TRandState* rand_state); - /**************** Data Structure ****************/ /*! @@ -60,16 +51,6 @@ struct CachedTrace { * \return Whether the cached trace is defined. */ inline bool Defined() const { return sch.defined(); } - /*! - * \brief Get trace from a cached trace. - * \return The trace. 
- */ - inline tir::Trace GetTrace() const { - Optional trace; - ICHECK(sch.defined() && (trace = sch->trace())) - << "Schedule or trace is not defined when getting trace!"; - return trace.value(); - } /*! \brief Reload the operator < for CachedTrace. */ friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { return lhs.score > rhs.score; @@ -148,6 +129,48 @@ struct PerThreadDataEx { /*! \brief Default constructor. */ PerThreadDataEx() = default; + /*! + * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + inline std::function()> MakeMutatorSampler( + double p_mutate, const Map& mutator_probs, + support::LinearCongruentialEngine::TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - p_mutate); + double total_mass_mutator = 0.0; + if (p_mutate > 0) { + for (const auto& kv : mutator_probs) { + const Mutator& mutator = kv.first; + double mass = kv.second->value; + CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator + << "' is ill-formed: " << mass; + total_mass_mutator += mass; + mutators.push_back(kv.first); + masses.push_back(mass * p_mutate); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } + /*! * \brief Set the value for the trace and mutator samplers per thread. * \param scores The predicted score for the given samples. @@ -164,21 +187,19 @@ struct PerThreadDataEx { struct ConcurrentBitmask { /*! 
The bit width. */ static constexpr const int kBitWidth = 64; - /*! - * \brief Constructor - * \param size The size of the concurrent bitmask. - */ - explicit ConcurrentBitmask(int size) : size(size) { - bitmask.assign((size + kBitWidth - 1) / kBitWidth, 0); - std::vector list((size + kBitWidth - 1) / kBitWidth); - mutexes.swap(list); - } /*! \brief The size of the concurrent bitmask. */ int size; /*! \brief The bitmasks. */ std::vector bitmask; /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. + */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} /*! * \brief Query and mark the given index if not visted before. * \param x The index to concurrently check if used. If not, mark as used. @@ -237,43 +258,6 @@ inline std::vector PredictNormalizedScore(const std::vector return scores; } -inline std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, - support::LinearCongruentialEngine::TRandState* rand_state) { - std::vector> mutators; - std::vector masses; - mutators.push_back(NullOpt); - masses.push_back(1.0 - p_mutate); - double total_mass_mutator = 0.0; - if (p_mutate > 0) { - for (const auto& kv : mutator_probs) { - const Mutator& mutator = kv.first; - double mass = kv.second->value; - CHECK_GE(mass, 0.0) << "ValueError: Probability of mutator '" << mutator - << "' is ill-formed: " << mass; - total_mass_mutator += mass; - mutators.push_back(kv.first); - masses.push_back(mass * p_mutate); - } - } - // Normalize the sum to 1.0 - if (total_mass_mutator == 0.0) { - masses[0] = 1.0; - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] = 0.0; - } - } else if (total_mass_mutator != 1.0) { - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] /= total_mass_mutator; - } - } - return [idx_sampler = tir::MakeMultinomialSampler(rand_state, 
masses), - mutators = std::move(mutators)]() -> Optional { - int i = idx_sampler(); - return mutators[i]; - }; -} - /**************** Evolutionary Search ****************/ /*! @@ -297,8 +281,6 @@ inline std::function()> MakeMutatorSampler( */ class EvolutionarySearchNode : public SearchStrategyNode { public: - using TRandState = support::LinearCongruentialEngine::TRandState; - /*! \brief The state of the search strategy. */ struct State { /*! \brief The search strategy itself */ @@ -441,8 +423,6 @@ class EvolutionarySearchNode : public SearchStrategyNode { void InitializeWithTuneContext(const TuneContext& tune_context) final { CHECK(tune_context.defined()) << "TuneContext must be defined!"; CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; - CHECK(p_mutate == 0 || tune_context->mutator_probs.defined()) - << "Mutators and their probabilities must be defined given mutation probability is not 0!"; CHECK(tune_context->target.defined()) << "Target must be defined!"; this->tune_context_ = tune_context; @@ -616,7 +596,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( if (Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); - if (Optional opt_new_trace = mutator->Apply(ctrace.GetTrace())) { + if (Optional opt_new_trace = mutator->Apply(ctrace.sch->trace().value())) { tir::Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { From 9f1db1e2c17653582dda648306629fc10e2fa0e6 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 6 Dec 2021 12:29:59 -0800 Subject: [PATCH 13/14] Fix cbmask. 
--- src/meta_schedule/search_strategy/evolutionary_search.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 8bdfc379ac..dd2252f478 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -206,7 +206,6 @@ struct ConcurrentBitmask { * \return Whether the index has been used before. */ bool QueryAndMark(int x) { - if (x < 0 || x >= size) return false; std::unique_lock lock(mutexes[x / kBitWidth]); constexpr uint64_t one = 1; if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { From 1949b047ffd502faa5bf283e456c04ba9cb98961 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 6 Dec 2021 15:09:07 -0800 Subject: [PATCH 14/14] Fix numbers. --- .../search_strategy/evolutionary_search.cc | 83 +++++++------------ 1 file changed, 30 insertions(+), 53 deletions(-) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index dd2252f478..070df7b487 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -46,11 +46,6 @@ struct CachedTrace { * \param score The predicted normalized score, -1.0 if score is not assigned yet. */ explicit CachedTrace(const tir::Schedule& sch, double score) : sch(sch), score(score) {} - /*! - * \brief Check if the cached trace is defined. - * \return Whether the cached trace is defined. - */ - inline bool Defined() const { return sch.defined(); } /*! \brief Reload the operator < for CachedTrace. 
*/ friend inline bool operator<(const CachedTrace& lhs, const CachedTrace& rhs) { return lhs.score > rhs.score; @@ -118,16 +113,13 @@ class SizedHeap { struct PerThreadData { IRModule mod; TRandState rand_state; -}; - -struct PerThreadDataEx { - IRModule mod; - TRandState rand_state; std::function trace_sampler; std::function()> mutator_sampler; /*! \brief Default constructor. */ - PerThreadDataEx() = default; + PerThreadData() = default; + explicit PerThreadData(const IRModule& mod, TRandState* rand_state) + : mod(mod), rand_state(ForkSeed(rand_state)) {} /*! * \brief Create a sampler function that picks mutators according to the mass function @@ -313,8 +305,8 @@ class EvolutionarySearchNode : public SearchStrategyNode { * \param unmeasured Unmeasured samples from replaying traces from design space. * \return The merged results, excluding undefined samples. */ - inline std::vector PruneAndMergeSamples( - const std::vector& measured, const std::vector& unmeasured); + inline std::vector MergeSamples(const std::vector& measured, + const std::vector& unmeasured); /*! * \brief Evolve the initial population using mutators and samplers. * \param inits The initial population of traces sampled. 
@@ -435,7 +427,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->per_thread_data_.reserve(this->num_threads_); for (int i = 0; i < this->num_threads_; i++) { this->per_thread_data_.push_back( - PerThreadData{DeepCopyIRModule(tune_context->mod.value()), ForkSeed(&this->rand_state_)}); + PerThreadData(DeepCopyIRModule(tune_context->mod.value()), &this->rand_state_)); } this->state_.reset(); } @@ -471,12 +463,12 @@ class EvolutionarySearchNode : public SearchStrategyNode { inline std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { std::vector measured_traces; measured_traces.reserve(num); - std::vector results(num, CachedTrace()); - - for (TuningRecord record : self->database_->GetTopK(self->token_, num)) { + Array top_records = self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { measured_traces.push_back(record->trace); } - + int acutal_num = measured_traces.size(); + std::vector results(acutal_num); auto f_proc_measured = [this, &measured_traces, &results](int thread_id, int trace_id) -> void { TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; const IRModule& mod = self->per_thread_data_[thread_id].mod; @@ -490,13 +482,13 @@ inline std::vector EvolutionarySearchNode::State::PickBestFromDatab throw; } }; - support::parallel_for_dynamic(0, measured_traces.size(), self->num_threads_, f_proc_measured); + support::parallel_for_dynamic(0, acutal_num, self->num_threads_, f_proc_measured); return results; } inline std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { // Pick unmeasured states - std::vector results; + std::vector results(num); auto f_proc_unmeasured = [this, &results, &num](int thread_id, int trace_id) -> void { TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; const IRModule& mod = self->per_thread_data_[thread_id].mod; @@ -512,39 +504,27 @@ inline std::vector EvolutionarySearchNode::State::SampleInitPopulat break; } 
} - if (!result.Defined()) { + if (!result.sch.defined()) { LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!"; } }; - results.resize(num, CachedTrace()); support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); return results; } -inline std::vector EvolutionarySearchNode::State::PruneAndMergeSamples( +inline std::vector EvolutionarySearchNode::State::MergeSamples( const std::vector& measured, const std::vector& unmeasured) { - std::vector pruned; - pruned.reserve(measured.size() + unmeasured.size()); - for (const CachedTrace& entry : measured) { - if (entry.Defined()) { - pruned.push_back(entry); - } - } - for (const CachedTrace& entry : unmeasured) { - if (entry.Defined()) { - pruned.push_back(entry); - } - } - return pruned; + ICHECK(measured.size() + unmeasured.size() == self->population) + << "Num of total init samples does not equal to population size!"; + std::vector inits; + inits.reserve(self->population); + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + return inits; } std::vector EvolutionarySearchNode::State::EvolveWithCostModel( const std::vector& inits, int num) { - std::vector per_thread_data_ex(self->num_threads_); - for (int i = 0; i < self->num_threads_; i++) { - per_thread_data_ex[i].mod = self->per_thread_data_[i].mod; - per_thread_data_ex[i].rand_state = ForkSeed(&self->per_thread_data_[i].rand_state); - } // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(num); @@ -575,18 +555,17 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( } // Set threaded samplers, with probability from predicated normalized throughputs for (int i = 0; i < self->num_threads_; ++i) { - per_thread_data_ex[i].Set(scores, self->p_mutate, self->mutator_probs_); + self->per_thread_data_[i].Set(scores, 
self->p_mutate, self->mutator_probs_); } ConcurrentBitmask cbmask(scores.size()); // The worker function - auto f_find_candidate = [&per_thread_data_ex, &cbmask, &sch_curr, &sch_next, this]( - int thread_id, int trace_id) { + auto f_find_candidate = [&cbmask, &sch_curr, &sch_next, this](int thread_id, int trace_id) { // Prepare samplers - TRandState& rand_state = per_thread_data_ex[thread_id].rand_state; - const IRModule& mod = per_thread_data_ex[thread_id].mod; - const std::function& trace_sampler = per_thread_data_ex[thread_id].trace_sampler; + TRandState& rand_state = self->per_thread_data_[thread_id].rand_state; + const IRModule& mod = self->per_thread_data_[thread_id].mod; + const std::function& trace_sampler = self->per_thread_data_[thread_id].trace_sampler; const std::function()>& mutator_sampler = - per_thread_data_ex[thread_id].mutator_sampler; + self->per_thread_data_[thread_id].mutator_sampler; CachedTrace& result = sch_next[trace_id]; // Loop until success for (int retry_cnt = 0; retry_cnt < self->max_evolve_fail_cnt; retry_cnt++) { @@ -599,11 +578,9 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( tir::Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = ApplyTrace(mod, new_trace, &rand_state, self->postprocs_)) { - tir::Schedule sch = opt_sch.value(); // note that sch's trace is different from new_trace // beacuase it contains post-processing infomation - CachedTrace new_ctrace(sch, -1.0); - result = new_ctrace; + result = CachedTrace(opt_sch.value(), -1.0); break; } } @@ -619,7 +596,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( } }; sch_next.clear(); - sch_next.resize(self->population, CachedTrace()); + sch_next.resize(self->population); support::parallel_for_dynamic(0, self->population, self->num_threads_, f_find_candidate); sch_curr.clear(); sch_curr.swap(sch_next); @@ -702,7 +679,7 @@ EvolutionarySearchNode::State::GenerateMeasureCandidates() { std::vector measured = 
PickBestFromDatabase(self->population * self->init_measured_ratio); std::vector unmeasured = SampleInitPopulation(self->population - measured.size()); - std::vector inits = PruneAndMergeSamples(measured, unmeasured); + std::vector inits = MergeSamples(measured, unmeasured); std::vector bests = EvolveWithCostModel(inits, sample_num); std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); return AssembleCandidates(picks, self->args_info_);