From 50f8d9987050f68b3e348f935a68f44f4db9a933 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 4 Jan 2022 20:44:18 -0800 Subject: [PATCH] Add evolutionary search. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- .../meta_schedule/search_strategy/__init__.py | 1 + .../search_strategy/evolutionary_search.py | 117 +++ .../search_strategy/evolutionary_search.cc | 673 ++++++++++++++++++ src/tir/schedule/primitive.h | 17 + src/tir/schedule/primitive/sampling.cc | 22 + .../test_meta_schedule_search_strategy.py | 173 ++++- 6 files changed, 997 insertions(+), 6 deletions(-) create mode 100644 python/tvm/meta_schedule/search_strategy/evolutionary_search.py create mode 100644 src/meta_schedule/search_strategy/evolutionary_search.cc diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index f385b72db46d7..174672235b426 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -23,3 +23,4 @@ from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace, ReplayTraceConfig from .replay_func import ReplayFunc, ReplayFuncConfig +from .evolutionary_search import EvolutionarySearch, EvolutionarySearchConfig diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 0000000000000..a679c19709511 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Evolutionary Search Strategy""" + +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + population_size : int + The initial population of traces from measured samples and randomly generated samples. + init_measured_ratio : int + The ratio of measured samples in the initial population. + init_max_fail_count : int + The maximum number to fail trace replaying. + genetic_num_iters : int + The number of iterations for genetic algorithm. 
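+        Each iteration re-scores the current population with the cost model and
+        produces the next population by sampling traces (weighted by predicted
+        score) and mutating them.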
+ genetic_mutate_prob : float + The probability of mutation. + genetic_max_fail_count : int + The maximum number to retry mutation. + eps_greedy : float + The ratio of greedy selected samples in the final picks. + """ + + num_trials_per_iter: int + num_trials_total: int + population_size: int + init_measured_ratio: int + init_max_fail_count: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + + def __init__( + self, + *, + num_trials_per_iter: int, + num_trials_total: int, + population_size: int, + init_measured_ratio: float, + init_max_fail_count: int, + genetic_num_iters: int, + genetic_mutate_prob: float, + genetic_max_fail_count: int, + eps_greedy: float, + ) -> None: + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population_size, + init_measured_ratio, + init_max_fail_count, + genetic_num_iters, + genetic_mutate_prob, + genetic_max_fail_count, + eps_greedy, + ) + + +class EvolutionarySearchConfig(NamedTuple): + """Configuration for EvolutionarySearch""" + + num_trials_per_iter: int + num_trials_total: int + population_size: int = 2048 + init_measured_ratio: float = 0.2 + init_max_fail_count: int = 64 + genetic_num_iters: int = 4 + genetic_mutate_prob: float = 0.85 + genetic_max_fail_count: int = 10 + eps_greedy: float = 0.05 + + def create_strategy(self) -> EvolutionarySearch: + return EvolutionarySearch( + num_trials_per_iter=self.num_trials_per_iter, + num_trials_total=self.num_trials_total, + population_size=self.population_size, + init_measured_ratio=self.init_measured_ratio, + init_max_fail_count=self.init_max_fail_count, + genetic_num_iters=self.genetic_num_iters, + genetic_mutate_prob=self.genetic_mutate_prob, + genetic_max_fail_count=self.genetic_max_fail_count, + eps_greedy=self.eps_greedy, + ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 0000000000000..cb35406c1d8f1 --- /dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/**************** Data Structure ****************/ + +/*! + * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. + * \note It maintains a min heap in terms of `Item::score`. 
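+ *       The comparator is deliberately inverted (`operator<` compares `score > other.score`),
+ *       so the standard `std::push_heap` / `std::pop_heap` calls keep the lowest-scored
+ *       item at the front.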
Therefore, when + * overflow happens, the element evicted is the one with the min `Item::score`. + * As time goes, the elements in the heap are going to be larger. + */ +class SizedHeap { + public: + struct Item { + Schedule sch; + IRModule mod; + size_t shash; + double score; + bool operator<(const Item& other) const { return score > other.score; } + }; + + struct ItemHash { + size_t operator()(const Item& hash) const { return hash.shash; } + }; + + struct ItemEqual { + bool operator()(const Item& lhs, const Item& rhs) const { + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); + } + }; + /*! + * \brief Constructor + * \param size_limit The up-limit of the heap size + */ + explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } + + /*! + * \brief Push the specific item to the heap if its key did not appears in the heap + * \param item The item to be pushed + */ + void Push(Schedule sch, IRModule mod, double score) { + Item item{sch, mod, StructuralHash()(mod), score}; + if (!in_heap.insert(item).second) { + return; + } + int size = heap.size(); + if (size < size_limit) { + // Heap is not full, just push + heap.emplace_back(item); + std::push_heap(heap.begin(), heap.end()); + } else if (item.score > heap.front().score) { + // if the item is better than the worst one in the heap, we can safely kick it out + std::pop_heap(heap.begin(), heap.end()); + heap.back() = item; + std::push_heap(heap.begin(), heap.end()); + } + // Otherwise, the item is worse than any other element in the heap + } + + /*! \brief Up-limit of the heap size */ + int size_limit; + /*! \brief The heap, the worse the topper */ + std::vector heap; + /*! \brief The traces that are in the heap */ + std::unordered_set in_heap; +}; + +struct PerThreadData { + IRModule mod{nullptr}; + TRandState rand_state{-1}; + std::function trace_sampler = nullptr; + std::function()> mutator_sampler = nullptr; + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param genetic_mutate_prob The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double genetic_mutate_prob, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); + } + + private: + /*! 
+ * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + static std::function()> MakeMutatorSampler( + double genetic_mutate_prob, // + const Map& mutator_probs, // + TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - genetic_mutate_prob); + double total_mass_mutator = 0.0; + if (genetic_mutate_prob > 0) { + for (const auto& kv : mutator_probs) { + Mutator mutator = kv.first; + double mass = kv.second->value; + total_mass_mutator += mass; + mutators.push_back(mutator); + masses.push_back(mass * genetic_mutate_prob); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + static constexpr const int kBitWidth = 64; + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. + */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} + /*! + * \brief Query and mark the given index if not visited before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. + */ + bool QueryAndMark(int x) { + constexpr uint64_t one = 1; + std::unique_lock lock(mutexes[x / kBitWidth]); + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { + return false; + } else { + bitmask[x / kBitWidth] |= one << (x % kBitWidth); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const Schedule& sch : picks) { + measure_inputs.push_back(MeasureCandidate(sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +std::vector PredictNormalizedScore(const std::vector& candidates, + const TuneContext& context, const CostModel& cost_model, + const Array& args_info) { + ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; + std::vector scores = + cost_model->Predict(context, AssembleCandidates(candidates, args_info)); + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + +/**************** Evolutionary Search ****************/ + +/*!\brief A search strategy that generates measure candidates using evolutionary search. 
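+ *
+ * In each round the strategy seeds a population with the best measured traces from the
+ * database plus randomly replayed design-space traces, evolves that population for
+ * `genetic_num_iters` iterations under cost-model guidance, and finally picks the
+ * candidates to measure with an epsilon-greedy mix of evolved bests and unmeasured samples.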
*/ +class EvolutionarySearchNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + EvolutionarySearchNode* self; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(EvolutionarySearchNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + /*! + * \brief Pick up best candidates from database. + * \param num The number of traces to produce. + * \return The picked best candidates. + */ + inline std::vector PickBestFromDatabase(int num); + /*! + * \brief Sample the initial population from previous measured results and randomly generated + * traces via trace replaying. + * \param num The number of traces to produce. + * \return The initial population of traces sampled. + */ + inline std::vector SampleInitPopulation(int num); + /*! + * \brief Evolve the initial population using mutators and samplers. + * \param population The initial population of traces sampled. + * \param num The number of traces to produce. + * \return The evolved traces from initial population. + */ + inline std::vector EvolveWithCostModel(std::vector population, int num); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param inits The initial population of traces sampled. + * \param bests The best candidates predicted from evolved traces. + * \param num The number of traces to produce. + * \return The final picked candidates with a ratio of both. + */ + inline std::vector PickWithEpsGreedy(const std::vector& inits, + const std::vector& bests, int num); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline Optional> GenerateMeasureCandidates(); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results); + }; + + /*! \brief The tuning context of the evolutionary search strategy. */ + const TuneContextNode* context_{nullptr}; + /*! \brief The target for the workload. */ + Target target_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The postprocessors. */ + Array postprocs_{nullptr}; + /*! \brief Mutators and their probability mass */ + Map mutator_probs_{nullptr}; + /*! \brief The number of threads to use. To be initialized with TuneContext. */ + int num_threads_; + /*! \brief The random state. To be initialized with TuneContext. */ + TRandState rand_state_; + /*! \brief Pre thread data including module to be tuned and random state. */ + std::vector per_thread_data_; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + /*! \brief The token registered for the given workload in database. */ + Workload token_{nullptr}; + + /*** Configuration: global ***/ + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. 
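+   *  Candidate generation stops once this many trials have been scheduled.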
*/ + int num_trials_total; + /*! \brief The population size in the evolutionary search. */ + int population_size; + /*** Configuration: the initial population ***/ + /*! \brief The ratio of measured states used in the initial population */ + double init_measured_ratio; + /*! \brief The maximum number to fail trace replaying. */ + int init_max_fail_count; + /*** Configuration: evolution ***/ + /*! \brief The number of iterations performed by generic algorithm. */ + int genetic_num_iters; + /*! \brief The probability to perform mutation */ + double genetic_mutate_prob; + /*! \brief The maximum number to try evolving the given trace. */ + int genetic_max_fail_count; + /*** Configuration: pick states for measurement ***/ + /*! \brief The ratio of measurements to use randomly sampled states. */ + double eps_greedy; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `context_` is not visited + // `target_` is not visited + // `args_info_` is not visited + // `database` is not visited + // `cost_model` is not visited + // `postprocs` is not visited + // `mutator_probs_` is not visited + // `num_threads` is not visited + // `rand_state_` is not visited + // `per_thread_data_` is not visited + // `state_` is not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("population_size", &population_size); + /*** Configuration: the initial population ***/ + v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("init_max_fail_count", &init_max_fail_count); + /*** Configuration: evolution ***/ + v->Visit("genetic_num_iters", &genetic_num_iters); + v->Visit("genetic_mutate_prob", &genetic_mutate_prob); + v->Visit("genetic_max_fail_count", &genetic_max_fail_count); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context.defined()) << "TuneContext must be defined!"; + CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(context->target.defined()) << "Target must be defined!"; + this->context_ = context.get(); + this->target_ = context->target.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); + this->mutator_probs_ = context->mutator_probs; + this->postprocs_ = context->postprocs; + this->num_threads_ = context->num_threads; + this->rand_state_ = ForkSeed(&context->rand_state); + this->cost_model_ = context->task_scheduler->cost_model.value(); + this->database_ = context->task_scheduler->database; + this->token_ = this->database_->CommitWorkload(context->mod.value()); + this->per_thread_data_.resize(this->num_threads_); + for (const auto& kv : this->mutator_probs_) { + double mass = kv.second->value; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); + } + for (PerThreadData& data : this->per_thread_data_) { + data.mod = DeepCopyIRModule(context->mod.value()); + data.rand_state = ForkSeed(&this->rand_state_); + } + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const Schedule& space : design_spaces) { + 
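+      // Keep the simplified trace of each design space; its decisions are ignored and
+      // re-sampled later when the initial population is generated.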
design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(context, measure_candidates, results); + } +}; + +std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + Array top_records = self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { + measured_traces.push_back(record->trace); + } + int actual_num = measured_traces.size(); + ThreadedTraceApply pp(self->postprocs_); + std::vector results(actual_num, Schedule{nullptr}); + auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, + int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + tir::Trace trace = measured_traces.at(trace_id); + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; + } + }; + support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); + return results; +} + +std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { + ThreadedTraceApply pp(self->postprocs_); + std::vector results(num, Schedule{nullptr}); + auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) { + int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); + tir::Trace trace(design_spaces[design_space_index]->insts, {}); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + break; + } + } + if (!result.defined()) { + LOG(FATAL) << "Sample-Init-Population failed over the maximum limit! 
Summary:\n" + << pp.SummarizeFailures(); + } + }; + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + return results; +} + +std::vector EvolutionarySearchNode::State::EvolveWithCostModel( + std::vector population, int num) { + ICHECK_GT(num, 0); + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(num); + for (int iter = 0;; ++iter) { + // Predict normalized score with the cost model, + std::vector scores = PredictNormalizedScore(population, // + GetRef(self->context_), // + self->cost_model_, // + self->args_info_); + ICHECK_EQ(scores.size(), population.size()); + for (int i = 0, n = population.size(); i < n; ++i) { + Schedule sch = population.at(i); + IRModule mod = sch->mod(); + double score = scores.at(i); + if (!self->database_->HasWorkload(mod)) { + heap.Push(sch, mod, score); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_num_iters) { + break; + } + // Set threaded samplers, with probability from predicated normalized throughputs + for (PerThreadData& data : self->per_thread_data_) { + data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); + } + ThreadedTraceApply pp(self->postprocs_); + ConcurrentBitmask cbmask(self->population_size); + std::vector next_population(self->population_size, Schedule{nullptr}); + // The worker function + auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, + int trace_id) { + // Prepare samplers + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + std::function& trace_sampler = data.trace_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; + Schedule& result = next_population.at(trace_id); + int sampled_trace_id = -1; + // Loop until success + for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { + sampled_trace_id = trace_sampler(); + tir::Trace trace = population.at(sampled_trace_id)->trace().value(); + if (Optional opt_mutator = mutator_sampler()) { + // Decision: mutate + Mutator mutator = opt_mutator.value(); + if (Optional new_trace = mutator->Apply(trace, rand_state)) { + if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + // note that sch's trace is different from new_trace + // because it contains post-processing information + result = sch.value(); + break; + } + } + } else if (cbmask.QueryAndMark(sampled_trace_id)) { + // Decision: do not mutate + break; + } + } + // if retry count exceeds the limit, reuse an old sample + if (!result.defined()) { + result = population.at(sampled_trace_id); + } + }; + support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); + population.swap(next_population); + LOG(INFO) << "Evolve iter #" << iter << " done. 
Summary:\n" << pp.SummarizeFailures(); + } + // Return the best states from the heap, sorting from higher score to lower ones + std::sort(heap.heap.begin(), heap.heap.end()); + std::vector results; + results.reserve(num); + for (const SizedHeap::Item& item : heap.heap) { + results.push_back(item.sch); + } + + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; + } + os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; + } + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + return results; +} + +std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( + const std::vector& unmeasured, const std::vector& bests, int num) { + int num_rands = num * self->eps_greedy; + int num_bests = num - num_rands; + std::vector rands = + tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + std::vector results; + results.reserve(num); + for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { + bool has_best = i_bests < static_cast(bests.size()); + bool has_rand = i_rands < static_cast(rands.size()); + // Pick a schedule + Schedule sch{nullptr}; + // If needs `bests`, then prefer `bests` + if (i < num_bests) { + if (has_best) { + sch = bests[i_bests++]; + } else if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else { + break; + } + } else { + // Else prefer `rands` + if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else if (has_best) { + sch = bests[i_bests++]; + } else { + break; + } + } + results.push_back(sch); + } + return results; +} + +Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + int sample_num = self->num_trials_per_iter; + if (ed > self->num_trials_total) { + sample_num = self->num_trials_total - st; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + int pop = self->population_size; + std::vector inits; + inits.reserve(pop); + + LOG(INFO) << "Generating candidates......"; + std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); + LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; + std::vector unmeasured = SampleInitPopulation(pop - measured.size()); + LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + ICHECK_EQ(inits.size(), self->population_size); + std::vector bests = EvolveWithCostModel(inits, sample_num); + LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + return AssembleCandidates(picks, self->args_info_); +} + +void EvolutionarySearchNode::State::NotifyRunnerResults( + const TuneContext& context, const Array& measure_candidates, + const Array& results) { + st += results.size(); + ed += results.size(); +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population_size, // + double init_measured_ratio, // + int init_max_fail_count, // + int genetic_num_iters, // + double genetic_mutate_prob, // 
+ int genetic_max_fail_count, // + double eps_greedy) { + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population_size = population_size; + n->init_measured_ratio = init_measured_ratio; + n->init_max_fail_count = init_max_fail_count; + n->genetic_num_iters = genetic_num_iters; + n->genetic_max_fail_count = genetic_max_fail_count; + n->genetic_mutate_prob = genetic_mutate_prob; + n->eps_greedy = eps_greedy; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 212e53aa500ff..45efd9f76cefa 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -36,6 +36,15 @@ namespace tir { */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e, no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. * \param rand_state The pointer to schedule's random state @@ -47,6 +56,14 @@ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_st TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. + */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! 
* \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572dbb8..83ef1e20be606 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -187,6 +187,28 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + ICHECK(!weights.empty()); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::uniform_real_distribution(0.0, sum), + sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rng); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; +} + std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index a4d32175eb0bd..b16eab7123753 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -17,16 +17,27 @@ """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring import sys +from typing import List, Optional, Tuple, Union + +import numpy as np import pytest import tvm +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.builder import LocalBuilder +from tvm.meta_schedule.cost_model import PyCostModel +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.mutator.mutator import PyMutator +from tvm.meta_schedule.runner import LocalRunner, RunnerResult from tvm.meta_schedule.search_strategy import ( + EvolutionarySearch, + MeasureCandidate, ReplayFunc, ReplayTrace, SearchStrategy, ) from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import RoundRobin from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -80,11 +91,11 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl num_trials_total = 20 strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - tune_context.space_generator.initialize_with_tune_context(tune_context) - spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) - strategy.initialize_with_tune_context(tune_context) + strategy.initialize_with_tune_context(context) strategy.pre_tuning(spaces) (correct_sch,) = 
ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] @@ -99,11 +110,161 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(tune_context, candidates, runner_results) + strategy.notify_runner_results(context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6] +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + strategy = EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population_size=5, + init_measured_ratio=0.1, + init_max_fail_count=10, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ) + context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + mutator_probs={ + DummyMutator(): 1.0, + }, + 
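+        # DummyMutator never fails: it rebuilds the trace without decisions, so the
+        # decisions are re-sampled when the mutated trace is replayed.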
target=tvm.target.Target("llvm"), + num_threads=1, # because we are using a mutator from the python side + ) + _scheduler = RoundRobin( + tasks=[context], + builder=LocalBuilder(), + runner=LocalRunner(), + database=DummyDatabase(), + cost_model=RandomModel(), + measure_callbacks=[], + ) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) + + strategy.initialize_with_tune_context(context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + print(num_trials_each_iter) + correct_count = 10 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( + [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + ) + del _scheduler + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))
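To make the selection policy in `PickWithEpsGreedy` easier to follow, here is a small,
self-contained Python sketch of the same rule (an illustration written for this description,
not code from the patch; the helper name `pick_with_eps_greedy` is made up, and the random
index shuffle done by `SampleWithoutReplacement` in the C++ version is elided): the first
`num - int(num * eps_greedy)` slots prefer the cost-model bests, the remaining slots prefer
unmeasured samples, and each side falls back to the other when it runs out.

from typing import List, Sequence, TypeVar

T = TypeVar("T")


def pick_with_eps_greedy(unmeasured: Sequence[T], bests: Sequence[T], num: int, eps: float) -> List[T]:
    """Mirror of PickWithEpsGreedy: prefer bests for the first slots, unmeasured for the rest."""
    num_rands = int(num * eps)  # truncated, like `int num_rands = num * self->eps_greedy;`
    num_bests = num - num_rands
    picks: List[T] = []
    i_bests = i_rands = 0
    for i in range(num):
        prefer_best = i < num_bests
        if prefer_best and i_bests < len(bests):
            picks.append(bests[i_bests])
            i_bests += 1
        elif i_rands < len(unmeasured):  # fall back to (or prefer) an unmeasured sample
            picks.append(unmeasured[i_rands])
            i_rands += 1
        elif i_bests < len(bests):  # unmeasured pool exhausted, fall back to the bests
            picks.append(bests[i_bests])
            i_bests += 1
        else:
            break  # both pools exhausted
    return picks


# With eps = 0.9 and num = 10 (the values used in the unit test), one slot prefers the
# evolved bests and nine prefer unmeasured samples; the second best is only used here
# because the unmeasured pool runs out.
print(pick_with_eps_greedy(list("abcdefgh"), ["best1", "best2"], num=10, eps=0.9))

Likewise, a rough sketch of the prefix-sum scheme used by `MakeMultinomialSampler` in
`sampling.cc` (again illustrative only; `make_multinomial_sampler` is a hypothetical
stand-in): draw `u` uniformly from `[0, sum(weights))` and return the first index whose
running sum is not less than `u`.

import bisect
import random
from typing import Callable, List, Sequence


def make_multinomial_sampler(weights: Sequence[float], rng: random.Random) -> Callable[[], int]:
    sums: List[float] = []
    total = 0.0
    for w in weights:
        total += w
        sums.append(total)  # running prefix sums, mirroring `sums` in the C++ lambda

    def sample() -> int:
        u = rng.uniform(0.0, total)
        idx = bisect.bisect_left(sums, u)  # counterpart of std::lower_bound
        return min(idx, len(sums) - 1)  # clamp, like the `(idx == n) ? (n - 1) : idx` guard

    return sample


# Index 0 plays the role of the "no mutation" mass and index 1 a mutator's mass,
# echoing MakeMutatorSampler above.
sampler = make_multinomial_sampler([0.15, 0.85], random.Random(42))
print([sampler() for _ in range(8)])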