From 14632f0424e818a485f4c9c89bde4fdc614483bc Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 18:44:53 -0700 Subject: [PATCH] [M3a] Sampling Primitives & Random Number Generator (#421) --- include/tvm/support/random_engine.h | 121 ++++ include/tvm/tir/schedule/schedule.h | 14 +- python/tvm/meta_schedule/utils.py | 12 +- src/meta_schedule/autotune.cc | 8 +- src/meta_schedule/autotune.h | 5 +- .../cost_model/rand_cost_model.cc | 17 +- src/meta_schedule/schedule.h | 3 - src/meta_schedule/search.cc | 53 +- src/meta_schedule/search.h | 12 +- src/meta_schedule/space/post_order_apply.cc | 26 +- src/meta_schedule/space/postproc.cc | 34 +- src/meta_schedule/space/postproc.h | 4 +- src/meta_schedule/space/schedule_fn.cc | 22 +- src/meta_schedule/space/search_rule.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 141 +++-- src/meta_schedule/strategy/mutator.cc | 62 +- src/meta_schedule/strategy/mutator.h | 4 +- src/meta_schedule/strategy/replay.cc | 18 +- src/tir/schedule/concrete_schedule.cc | 18 +- src/tir/schedule/concrete_schedule.h | 17 +- src/tir/schedule/primitive.h | 105 +++- src/tir/schedule/primitive/sampling.cc | 560 ++++++++++++++++- src/tir/schedule/sampler.cc | 592 ------------------ src/tir/schedule/sampler.h | 151 ----- src/tir/schedule/traced_schedule.cc | 23 +- src/tir/schedule/traced_schedule.h | 4 +- tests/cpp/meta_schedule_test.cc | 37 ++ tests/cpp/random_engine_test.cc | 71 +++ ...test_meta_schedule_bsr_sparse_dense_cpu.py | 48 +- .../test_meta_schedule_feature.py | 2 +- ...st_meta_schedule_layout_rewrite_network.py | 4 +- .../test_resnet_end_to_end_cuda.py | 13 +- 32 files changed, 1183 insertions(+), 1020 deletions(-) create mode 100644 include/tvm/support/random_engine.h delete mode 100644 src/tir/schedule/sampler.cc delete mode 100644 src/tir/schedule/sampler.h create mode 100644 tests/cpp/meta_schedule_test.cc create mode 100644 tests/cpp/random_engine_test.cc diff --git a/include/tvm/support/random_engine.h 
b/include/tvm/support/random_engine.h new file mode 100644 index 0000000000..9713b69d42 --- /dev/null +++ b/include/tvm/support/random_engine.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file random_engine.h + * \brief Random number generator, for Sampling functions. + */ + +#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ +#define TVM_SUPPORT_RANDOM_ENGINE_H_ + +#include <tvm/runtime/logging.h> + +#include <cstdint>  // for uint64_t + +namespace tvm { +namespace support { + +/*! + * \brief This linear congruential engine is a drop-in replacement for std::minstd_rand. It strictly + * corresponds to std::minstd_rand and is designed to be platform-independent. + * \note Our linear congruential engine is a complete implementation of + * std::uniform_random_bit_generator so it can be used as generator for any STL random number + * distribution. However, parts of std::linear_congruential_engine's member functions are not + * included for simplification. For full member functions of std::minstd_rand, please check out the + * following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine + */ +class LinearCongruentialEngine { + public: + /*! 
+ * \brief The result type is defined as uint64_t here to avoid overflow. + * \note The type name is not in Google style because it is used in STL's distribution interface. + */ + using result_type = uint64_t; + using TRandState = int64_t; + + /*! \brief The multiplier */ + static constexpr TRandState multiplier = 48271; + + /*! \brief The increment */ + static constexpr TRandState increment = 0; + + /*! \brief The modulus */ + static constexpr TRandState modulus = 2147483647; + + /*! + * \brief The minimum possible value of random state here. + * \note The function name is uncapitalized because it is used in STL's distribution interface. + */ + static constexpr result_type min() { return 0; } + + /*! + * \brief The maximum possible value of random state here. + * \note The function name is uncapitalized because it is used in STL's distribution interface. + */ + static constexpr result_type max() { return modulus - 1; } + + /*! + * \brief Operator to move the random state to the next and return the new random state. According + * to definition of linear congruential engine, the new random state value is computed as + * new_random_state = (current_random_state * multiplier + increment) % modulus. + * \return The next current random state value in the type of result_type. + * \note In order for better efficiency, the implementation here has a few assumptions: + * 1. The multiplication and addition won't overflow. + * 2. The given random state pointer `rand_state_ptr` is not nullptr. + * 3. The given random state `*(rand_state_ptr)` is in the range of [0, modulus - 1]. + */ + result_type operator()() { + (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; + return *rand_state_ptr_; + } + + /*! + * \brief Change the start random state of RNG with the seed of a new random state value. + * \param rand_state The random state given in result_type. 
+ */ + void Seed(TRandState rand_state = 1) { + rand_state %= modulus; // Make sure the seed is within the range of modulus. + if (rand_state == 0) + rand_state = 1; // Avoid getting all 0 given the current parameter set. + else if (rand_state < 0) + rand_state += modulus; // Make sure the rand state is non-negative. + ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. + *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. + } + + /*! + * \brief Construct a random number generator with a random state pointer. + * \param rand_state_ptr The random state pointer given in result_type*. + * \note The random state is not checked for whether it's nullptr and whether it's in the range of + * [0, modulus-1]. We assume the given random state is valid or the Seed function would be + * called right after the constructor before any usage. + */ + explicit LinearCongruentialEngine(TRandState* rand_state_ptr) { + rand_state_ptr_ = rand_state_ptr; + } + + private: + TRandState* rand_state_ptr_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_RANDOM_ENGINE_H_ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9148a34760..02829f37fd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -19,11 +19,15 @@ #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ +#include #include namespace tvm { namespace tir { +using TRandState = support::LinearCongruentialEngine::TRandState; +using RandEngine = support::LinearCongruentialEngine; + /*! \brief The level of detailed error message rendering */ enum class ScheduleErrorRenderLevel : int32_t { /*! 
\brief Render a detailed error message */ @@ -113,12 +117,12 @@ class ScheduleNode : public runtime::Object { * 3) All the random variables are valid in the copy, pointing to the correpsonding sref * reconstructed */ - virtual Schedule Copy(int64_t seed = -1) const = 0; + virtual Schedule Copy(tir::TRandState seed = -1) const = 0; /*! * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(int64_t seed = -1) = 0; + virtual void Seed(tir::TRandState seed = -1) = 0; /*! \brief Fork the random state */ virtual int64_t ForkSeed() = 0; @@ -502,11 +506,11 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Concrete(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Meta(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Meta(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Traced(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Traced(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 294098cf50..83a46e6cc0 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -46,7 +46,7 @@ def make_error_msg() -> str: - """ Get the error message from traceback. 
""" + """Get the error message from traceback.""" error_msg = str(traceback.format_exc()) if len(error_msg) > MAX_ERROR_MSG_LEN: error_msg = ( @@ -408,7 +408,7 @@ def local_builder_worker( timeout: int, verbose: int, ) -> BuildResult.TYPE: - """ Local worker for ProgramBuilder """ + """Local worker for ProgramBuilder""" # deal with build_func build_func = { "tar": build_func_tar.tar, # export to tar @@ -603,7 +603,7 @@ def rpc_runner_worker( f_create_args: Callable[[Device], List[NDArray]], verbose: int, ) -> MeasureResult.TYPE: - """ RPC worker for ProgramRunner """ + """RPC worker for ProgramRunner""" measure_input = measure_inputs[index] build_result = build_results[index] @@ -653,13 +653,13 @@ def timed_func(): else: rpc_eval_repeat = 1 if f_create_args is not None: - args_set = [f_create_args(ctx) for _ in range(rpc_eval_repeat)] + args_set = [f_create_args(dev) for _ in range(rpc_eval_repeat)] else: args_set = [ - realize_arguments(remote, ctx, measure_input.sch.mod["main"]) + realize_arguments(remote, dev, measure_input.sch.mod["main"]) for _ in range(rpc_eval_repeat) ] - ctx.sync() + dev.sync() costs = sum([time_f(*args).results for args in args_set], ()) # clean up remote files remote.remove(build_result.filename) diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index a2ce7fb4a0..3ad73fd99f 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -24,8 +24,10 @@ namespace tvm { namespace meta_schedule { void TuneContextNode::Init(Optional seed) { - if (seed.defined()) { - this->sampler.Seed(seed.value()->value); + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&this->rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&this->rand_state).Seed(std::random_device()()); } if (task.defined()) { task.value()->Init(this); @@ -59,7 +61,7 @@ void TuneContextNode::Init(Optional seed) { bool TuneContextNode::Postprocess(const Schedule& sch) { sch->EnterPostproc(); for (const 
Postproc& postproc : postprocs) { - if (!postproc->Apply(task.value(), sch, &sampler)) { + if (!postproc->Apply(task.value(), sch, &rand_state)) { return false; } } diff --git a/src/meta_schedule/autotune.h b/src/meta_schedule/autotune.h index dc7c391f50..92ae034a53 100644 --- a/src/meta_schedule/autotune.h +++ b/src/meta_schedule/autotune.h @@ -43,7 +43,8 @@ class TuneContextNode : public runtime::Object { Array postprocs; Array measure_callbacks; int num_threads; - Sampler sampler; + + tir::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); @@ -56,7 +57,6 @@ class TuneContextNode : public runtime::Object { v->Visit("postprocs", &postprocs); v->Visit("measure_callbacks", &measure_callbacks); v->Visit("num_threads", &num_threads); - // `sampler` is not visited } void Init(Optional seed = NullOpt); @@ -93,7 +93,6 @@ class TuneContext : public runtime::ObjectRef { n->postprocs = postprocs; n->measure_callbacks = measure_callbacks; n->num_threads = num_threads; - // `n->sampler` is not initialized data_ = std::move(n); (*this)->Init(seed); } diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 697e9aefe2..dba3c1b93d 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -27,12 +27,10 @@ namespace meta_schedule { /*! \brief The cost model returning random value for all predictions */ class RandCostModelNode : public CostModelNode { public: - /*! \brief A sampler for generating random numbers */ - Sampler sampler; + /*! \brief A random state for sampling functions to generate random numbers */ + tir::TRandState rand_state; - void VisitAttrs(tvm::AttrVisitor* v) { - // sampler is not visited - } + void VisitAttrs(tvm::AttrVisitor* v) {} /*! * \brief Update the cost model according to new measurement results (training data). 
@@ -48,7 +46,7 @@ class RandCostModelNode : public CostModelNode { * \return The predicted scores for all states */ std::vector Predict(const SearchTask& task, const Array& states) override { - return sampler.SampleUniform(states.size(), 0.0, 1.0); + return tir::SampleUniform(&rand_state, states.size(), 0.0, 1.0); } static constexpr const char* _type_key = "meta_schedule.RandCostModel"; @@ -61,11 +59,10 @@ class RandCostModelNode : public CostModelNode { */ class RandCostModel : public CostModel { public: - RandCostModel() { data_ = make_object(); } - - explicit RandCostModel(int seed) { + explicit RandCostModel(int seed = -1) { ObjectPtr n = make_object(); - n->sampler.Seed(seed); + if (seed == -1) seed = std::random_device()(); + tir::RandEngine(&n->rand_state).Seed(seed); data_ = std::move(n); } diff --git a/src/meta_schedule/schedule.h b/src/meta_schedule/schedule.h index 96844ea325..cf94400ca6 100644 --- a/src/meta_schedule/schedule.h +++ b/src/meta_schedule/schedule.h @@ -22,14 +22,11 @@ #include #include -#include "../tir/schedule/sampler.h" - namespace tvm { namespace meta_schedule { using ScheduleNode = tir::TraceNode; using Schedule = tir::Schedule; -using Sampler = tir::Sampler; using BlockRV = tir::BlockRV; using BlockRVNode = tir::BlockRVNode; using LoopRV = tir::LoopRV; diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index 75f2b7acdd..e4b161049c 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,17 +58,20 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + 
tir::RandEngine(&rand_state).Seed(std::random_device()()); } + if (verbose) { LOG(INFO) << "Tuning for task: " << task; } space->Init(task); strategy->Init(task); measurer->Init(task); - return strategy->Search(task, space, measurer, &seeded, verbose); + return strategy->Search(task, space, measurer, &rand_state, verbose); } /********** Printer **********/ @@ -101,17 +104,19 @@ struct Internal { * \brief Apply postprocessors onto the schedule * \param space The search space * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The random state for sampling * \return Whether postprocessing has succeeded * \sa SearchSpaceNode::Postprocess */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return space->Postprocess(task, sch, &seeded); + return space->Postprocess(task, sch, &rand_state); } /*! * \brief Sample a schedule out of the search space, calls SearchSpaceNode::SampleSchedule @@ -122,11 +127,13 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return space->SampleSchedule(task, &seeded); + return space->SampleSchedule(task, &rand_state); } /*! 
* \brief Get support of the search space, calls SearchSpaceNode::GetSupport @@ -138,11 +145,13 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return space->GetSupport(task, &seeded); + return space->GetSupport(task, &rand_state); } /*! * \brief Explore the search space and find the best schedule @@ -156,11 +165,13 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return strategy->Search(task, space, measurer, &seeded, verbose); + return strategy->Search(task, space, measurer, &rand_state, verbose); } }; diff --git a/src/meta_schedule/search.h b/src/meta_schedule/search.h index 7e2edfc06e..d184a0b48a 100644 --- a/src/meta_schedule/search.h +++ b/src/meta_schedule/search.h @@ -21,6 +21,7 @@ #include +#include "../tir/schedule/primitive.h" #include "./schedule.h" namespace tvm { @@ -101,22 +102,23 @@ class SearchSpaceNode : public runtime::Object { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The random state for sampling */ - virtual bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) = 0; + virtual bool Postprocess(const SearchTask& task, 
const Schedule& sch, + tir::TRandState* rand_state) = 0; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - virtual Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) = 0; + virtual Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) = 0; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return The support of the search space. Any point from the search space should along to one of * the traces returned */ - virtual Array GetSupport(const SearchTask& task, Sampler* sampler) = 0; + virtual Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) = 0; static constexpr const char* _type_key = "meta_schedule.SearchSpace"; TVM_DECLARE_BASE_OBJECT_INFO(SearchSpaceNode, Object); @@ -156,7 +158,7 @@ class SearchStrategyNode : public Object { * \return The best schedule found, NullOpt if no valid schedule is found */ virtual Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) = 0; /*! 
\brief Explore the search space */ virtual void Search() { LOG(FATAL) << "NotImplemented"; } diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 7d32178893..58a05658e8 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -49,22 +49,23 @@ class PostOrderApplyNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The random state for sampling */ - bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) override; + bool Postprocess(const SearchTask& task, const Schedule& sch, + tir::TRandState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) override; + Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) override; /*! 
* \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa PostOrderApplyNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler* sampler) override; + Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SearchSpaceNode); @@ -97,20 +98,20 @@ PostOrderApply::PostOrderApply(Array stages, Array postpro /********** Sampling **********/ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler* sampler) { + tir::TRandState* rand_state) { sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { - if (!postproc->Apply(task, sch, sampler)) { + if (!postproc->Apply(task, sch, rand_state)) { return false; } } return true; } -Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, Sampler* sampler) { - Array support = GetSupport(task, sampler); +Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) { + Array support = GetSupport(task, rand_state); ICHECK(!support.empty()) << "ValueError: Found null support"; - int i = sampler->SampleInt(0, support.size()); + int i = tir::SampleInt(rand_state, 0, support.size()); return support[i]; } @@ -146,12 +147,13 @@ class BlockCollector : public tir::StmtVisitor { const tir::BlockNode* root_block_; }; -Array PostOrderApplyNode::GetSupport(const SearchTask& task, Sampler* sampler) { +Array PostOrderApplyNode::GetSupport(const SearchTask& task, + tir::TRandState* rand_state) { using ScheduleAndUnvisitedBlocks = std::pair>; Array curr{ Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail)}; for (const 
SearchRule& rule : stages) { @@ -201,7 +203,7 @@ Array PostOrderApplyNode::GetSupport(const SearchTask& task, Sampler* Trace trace = sch->trace().value()->Simplified(/*remove_postproc=*/true); Schedule new_sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(new_sch, /*remove_postproc=*/true); diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index 4e9c229160..c0e1c599f6 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -38,8 +38,8 @@ Postproc::Postproc(String name, FProc proc) { /********** Postproc **********/ -bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, Sampler* sampler) { - return proc_(task, sch, sampler); +bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) { + return proc_(task, sch, rand_state); } /********** RewriteTensorize **********/ @@ -115,7 +115,7 @@ class PostprocRewriteTensorize { Postproc RewriteTensorize(Array tensor_intrins) { auto f_proc = [tensor_intrins{std::move(tensor_intrins)}](SearchTask task, Schedule self, - void* _sampler) -> bool { + void* _rand_state) -> bool { return PostprocRewriteTensorize(tensor_intrins).Proc(self); }; return Postproc("rewrite_tensorize", f_proc); @@ -181,7 +181,7 @@ class PostprocRewriteCooperativeFetch { }; Postproc RewriteCooperativeFetch() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteCooperativeFetch().Proc(sch); }; return Postproc("rewrite_cooperative_fetch", f_proc); @@ -498,7 +498,7 @@ class PostprocRewriteParallelizeVectorizeUnroll { }; Postproc RewriteParallelizeVectorizeUnroll() { - auto f_proc = [](SearchTask task, 
Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteParallelizeVectorizeUnroll().Proc(sch); }; return Postproc("rewrite_parallelize_vectorize_unroll", f_proc); @@ -632,7 +632,7 @@ class PostprocRewriteUnboundBlocks { }; Postproc RewriteUnboundBlocks() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteUnboundBlocks().Proc(task, sch); }; return Postproc("rewrite_unbound_blocks", f_proc); @@ -764,7 +764,7 @@ class PostprocRewriteReductionBlock { }; Postproc RewriteReductionBlock() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteReductionBlock().Proc(sch); }; return Postproc("rewrite_reduction_block", f_proc); @@ -794,7 +794,7 @@ class PostprocDisallowDynamicLoops { }; Postproc DisallowDynamicLoops() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocDisallowDynamicLoops().Proc(sch); }; return Postproc("disallow_dynamic_loops", f_proc); @@ -849,7 +849,7 @@ class PostprocVerifyGPUCode { }; Postproc VerifyGPUCode() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocVerifyGPUCode().Proc(task, sch); }; return Postproc("verify_gpu_code", f_proc); @@ -1075,8 +1075,8 @@ class PostProcRewriteLayout { } // Step 1: create a new buffer tir::Buffer new_buffer(buffer->data, buffer->dtype, new_shape, Array(), - buffer->elem_offset, buffer->name, - buffer->data_alignment, buffer->offset_factor, buffer->buffer_type); + buffer->elem_offset, buffer->name, buffer->data_alignment, + buffer->offset_factor, 
buffer->buffer_type); // Step 2: do the rewrite to the buffer access // the rule is as below: // for example, @@ -1104,7 +1104,7 @@ class PostProcRewriteLayout { }; Postproc RewriteLayout() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostProcRewriteLayout().Proc(sch, task); }; return Postproc("rewrite_layout", f_proc); @@ -1118,11 +1118,13 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return self->Apply(task, sch, &seeded); + return self->Apply(task, sch, &rand_state); } }; diff --git a/src/meta_schedule/space/postproc.h b/src/meta_schedule/space/postproc.h index 2e9bd604fe..770559820f 100644 --- a/src/meta_schedule/space/postproc.h +++ b/src/meta_schedule/space/postproc.h @@ -44,10 +44,10 @@ class PostprocNode : public Object { /*! 
* \brief Apply the postprocessor * \param sch The schedule to be processed - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \return If the post-processing succeeds */ - bool Apply(const SearchTask& task, const Schedule& sch, Sampler* sampler); + bool Apply(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index 5d056c328f..c009f1a09c 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -47,22 +47,23 @@ class ScheduleFnNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The random state for sampling */ - bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) override; + bool Postprocess(const SearchTask& task, const Schedule& sch, + tir::TRandState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) override; + Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) override; /*! 
* \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa ScheduleFnNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler* sampler) override; + Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SearchSpaceNode); @@ -94,27 +95,28 @@ ScheduleFn::ScheduleFn(PackedFunc sch_fn, Array postprocs) { /********** Sampling **********/ -bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) { +bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, + tir::TRandState* rand_state) { sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { - if (!postproc->Apply(task, sch, sampler)) { + if (!postproc->Apply(task, sch, rand_state)) { return false; } } return true; } -Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler* sampler) { +Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); this->sch_fn_(sch); return sch; } -Array ScheduleFnNode::GetSupport(const SearchTask& task, Sampler* sampler) { - return {SampleSchedule(task, sampler)}; +Array ScheduleFnNode::GetSupport(const SearchTask& task, tir::TRandState* rand_state) { + return {SampleSchedule(task, rand_state)}; } /********** FFI **********/ diff --git a/src/meta_schedule/space/search_rule.cc b/src/meta_schedule/space/search_rule.cc index 406740cbb3..b481d36c46 100644 --- a/src/meta_schedule/space/search_rule.cc +++ b/src/meta_schedule/space/search_rule.cc @@ -27,8 +27,8 @@ namespace 
tvm { namespace meta_schedule { /**************** TIR Nodes ****************/ -using tir::ForNode; using tir::BlockNode; +using tir::ForNode; /********** Constructors **********/ diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index df1cd19148..9b1c9c6357 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -134,12 +134,12 @@ class EvolutionaryNode : public SearchStrategyNode { * \param task The search task * \param space The search space * \param measurer The measurer that builds, runs and profiles sampled programs - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \param verbose Whether or not in verbose mode * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) override; /********** Stages in evolutionary search **********/ @@ -151,22 +151,22 @@ class EvolutionaryNode : public SearchStrategyNode { * \param support The support to be sampled from * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \return The generated samples, all of which are not post-processed */ Array SampleInitPopulation(const Array& support, const SearchTask& task, - const SearchSpace& space, Sampler* sampler); + const SearchSpace& space, tir::TRandState* rand_state); /*! 
* \brief Perform evolutionary search using genetic algorithm with the cost model * \param inits The initial population * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \return An array of schedules, the sampling result */ Array EvolveWithCostModel(const Array& inits, const SearchTask& task, - const SearchSpace& space, Sampler* sampler); + const SearchSpace& space, tir::TRandState* rand_state); /*! * \brief Pick a batch of samples for measurement with epsilon greedy @@ -174,12 +174,12 @@ class EvolutionaryNode : public SearchStrategyNode { * \param bests The best populations according to the cost model when picking top states * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \return A list of schedules, result of epsilon-greedy sampling */ Array PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler* sampler); + tir::TRandState* rand_state); /*! * \brief Make measurements and update the cost model @@ -200,16 +200,16 @@ class EvolutionaryNode : public SearchStrategyNode { friend class Evolutionary; /*! 
- * \brief Fork a sampler into `n` samplers - * \param n The number of samplers to be forked - * \param sampler The sampler to be forked - * \return A list of samplers, the result of forking + * \brief Fork a random state into `n` random states + * \param n The number of random states to be forked + * \param rand_state The random state for sampling + * \return A list of random states, the result of forking */ - static std::vector ForkSamplers(int n, Sampler* sampler) { - std::vector result; + static std::vector ForkRandStates(int n, tir::TRandState* rand_state) { + std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) { - result.emplace_back(sampler->ForkSeed()); + result.emplace_back(tir::ForkSeed(rand_state)); } return result; } @@ -226,19 +226,16 @@ class EvolutionaryNode : public SearchStrategyNode { /*! * \brief Replay the trace and do postprocessing - * \param n The number of samplers to be forked - * \param sampler The sampler to be forked - * \return A list of samplers, the result of forking */ static Optional ReplayTrace(const Trace& trace, const SearchTask& task, - const SearchSpace& space, Sampler* sampler, + const SearchSpace& space, tir::TRandState* rand_state, const tir::PrimFunc& workload) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - if (!space->Postprocess(task, sch, sampler)) { + if (!space->Postprocess(task, sch, rand_state)) { return NullOpt; } return sch; @@ -246,11 +243,11 @@ class EvolutionaryNode : public SearchStrategyNode { /*! 
* \brief Create a sampler function that picks mutators according to the mass function - * \param sampler The source of randomness + * \param rand_state The random state for sampling * \return The sampler created */ static std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, Sampler* sampler) { + double p_mutate, const Map& mutator_probs, tir::TRandState* rand_state) { CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // << "ValueError: Probability should be within [0, 1], " << "but get `p_mutate = " << p_mutate << '\''; @@ -279,7 +276,7 @@ class EvolutionaryNode : public SearchStrategyNode { masses[i] /= total_mass_mutator; } } - auto idx_sampler = sampler->MakeMultinomial(masses); + auto idx_sampler = tir::MakeMultinomial(rand_state, masses); return [idx_sampler = std::move(idx_sampler), mutators = std::move(mutators)]() -> Optional { int i = idx_sampler(); @@ -312,7 +309,6 @@ class EvolutionaryNode : public SearchStrategyNode { * \param candidates The candidates for prediction * \param task The search task * \param space The search space - * \param sampler Source of randomness * \return The normalized throughput in the prediction */ std::vector PredictNormalizedThroughput(const std::vector& candidates, @@ -427,8 +423,8 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i CHECK_LE(num_measures_per_iteration, population) << "ValueError: requires `num_measures_per_iteration <= population`"; { - Sampler sampler(42); - EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &sampler); + tir::TRandState rand_state = 42; + EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); } ObjectPtr n = make_object(); n->total_measures = total_measures; @@ -447,23 +443,23 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i /********** Search **********/ Optional EvolutionaryNode::Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, 
Sampler* sampler, - int verbose) { - Array support = space->GetSupport(task, sampler); + const ProgramMeasurer& measurer, + tir::TRandState* rand_state, int verbose) { + Array support = space->GetSupport(task, rand_state); int iter = 1; for (int num_measured = 0; num_measured < this->total_measures; ++iter) { LOG(INFO) << "Evolutionary search: Iteration #" << iter << " | Measured: " << num_measured << "/" << this->total_measures; // `inits`: Sampled initial population, whose size is at most `this->population` LOG(INFO) << "Sampling initial population..."; - Array inits = SampleInitPopulation(support, task, space, sampler); + Array inits = SampleInitPopulation(support, task, space, rand_state); LOG(INFO) << "Initial population size: " << inits.size(); // `bests`: The best schedules according to the cost mode when explore the space using mutators LOG(INFO) << "Evolving..."; - Array bests = EvolveWithCostModel(inits, task, space, sampler); + Array bests = EvolveWithCostModel(inits, task, space, rand_state); // Pick candidates with eps greedy LOG(INFO) << "Picking with epsilon greedy where epsilon = " << eps_greedy; - Array picks = PickWithEpsGreedy(inits, bests, task, space, sampler); + Array picks = PickWithEpsGreedy(inits, bests, task, space, rand_state); // Run measurement, update cost model LOG(INFO) << "Sending " << picks.size() << " samples for measurement"; Array results = MeasureAndUpdateCostModel(task, picks, measurer, verbose); @@ -475,25 +471,25 @@ Optional EvolutionaryNode::Search(const SearchTask& task, const Search Array EvolutionaryNode::SampleInitPopulation(const Array& support, const SearchTask& task, const SearchSpace& space, - Sampler* global_sampler) { + tir::TRandState* global_rand_state) { trace_cache_.clear(); std::vector results; results.reserve(this->population); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_samplers = ForkSamplers(num_threads, global_sampler); + std::vector thread_rand_states = 
ForkRandStates(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); // Pick measured states int num_measured = this->population * this->init_measured_ratio; for (const Database::Entry& entry : database->GetTopK(num_measured, task)) { results.push_back(entry.trace.value()); } - auto f_proc_measured = [this, &results, &thread_samplers, &task, &space, thread_workloads]( + auto f_proc_measured = [this, &results, &thread_rand_states, &task, &space, thread_workloads]( int thread_id, int i) -> void { - Sampler* sampler = &thread_samplers[thread_id]; + tir::TRandState* rand_state = &thread_rand_states[thread_id]; const Trace& trace = results[i]; if (Optional opt_sch = - ReplayTrace(trace, task, space, sampler, thread_workloads[thread_id])) { + ReplayTrace(trace, task, space, rand_state, thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); this->AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); } else { @@ -505,15 +501,16 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo // Pick unmeasured states std::atomic tot_fail_ct(0); std::atomic success_ct(0); - auto f_proc_unmeasured = [this, &results, &thread_samplers, &tot_fail_ct, &task, &space, &support, - &success_ct, thread_workloads](int thread_id, int i) -> void { - Sampler* sampler = &thread_samplers[thread_id]; + auto f_proc_unmeasured = [this, &results, &thread_rand_states, &tot_fail_ct, &task, &space, + &support, &success_ct, thread_workloads](int thread_id, int i) -> void { + tir::TRandState* rand_state = &thread_rand_states[thread_id]; for (;;) { - Trace support_trace = support[sampler->SampleInt(0, support.size())]->trace().value(); + Trace support_trace = support[tir::SampleInt(rand_state, 0, support.size())]->trace().value(); Map decisions; try { - if (Optional opt_sch = ReplayTrace(Trace(support_trace->insts, decisions), task, - space, sampler, thread_workloads[thread_id])) { + if (Optional opt_sch = + 
ReplayTrace(Trace(support_trace->insts, decisions), task, space, rand_state, + thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); Trace old_trace = sch->trace().value(); Trace trace(old_trace->insts, old_trace->decisions); @@ -547,24 +544,25 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, const SearchTask& task, const SearchSpace& space, - Sampler* global_sampler) { + tir::TRandState* global_rand_state) { // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(this->num_measures_per_iteration); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_samplers = ForkSamplers(num_threads, global_sampler); + std::vector thread_rand_states = ForkRandStates(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); std::vector> thread_trace_samplers(num_threads); std::vector()>> thread_mutator_samplers(num_threads); std::vector trace_used; std::mutex trace_used_mutex; - auto f_set_sampler = [this, num_threads, &thread_samplers, &thread_trace_samplers, + auto f_set_sampler = [this, num_threads, &thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used](const std::vector& scores) { for (int i = 0; i < num_threads; ++i) { - Sampler* sampler = &thread_samplers[i]; - thread_trace_samplers[i] = sampler->MakeMultinomial(scores); - thread_mutator_samplers[i] = MakeMutatorSampler(this->p_mutate, this->mutator_probs, sampler); + tir::TRandState* rand_state = &thread_rand_states[i]; + thread_trace_samplers[i] = tir::MakeMultinomial(rand_state, scores); + thread_mutator_samplers[i] = + MakeMutatorSampler(this->p_mutate, this->mutator_probs, rand_state); } trace_used = std::vector(scores.size(), 0); }; @@ -595,11 +593,11 @@ Array 
EvolutionaryNode::EvolveWithCostModel(const Array& inits, // Set threaded samplers, with probability from predicated normalized throughputs f_set_sampler(scores); // The worker function - auto f_find_candidate = [&thread_samplers, &thread_trace_samplers, &thread_mutator_samplers, + auto f_find_candidate = [&thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used, &trace_used_mutex, &sch_curr, &sch_next, &task, &space, thread_workloads, this](int thread_id, int i) { // Prepare samplers - Sampler* sampler = &thread_samplers[thread_id]; + tir::TRandState* rand_state = &thread_rand_states[thread_id]; const std::function& trace_sampler = thread_trace_samplers[thread_id]; const std::function()>& mutator_sampler = thread_mutator_samplers[thread_id]; @@ -613,10 +611,10 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, // Decision: mutate Mutator mutator = opt_mutator.value(); if (Optional opt_new_trace = - mutator->Apply(task, GetRef(cached_trace.trace), sampler)) { + mutator->Apply(task, GetRef(cached_trace.trace), rand_state)) { Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = - ReplayTrace(new_trace, task, space, sampler, thread_workloads[thread_id])) { + ReplayTrace(new_trace, task, space, rand_state, thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); CachedTrace new_cached_trace{new_trace.get(), sch, Repr(sch), -1.0}; this->AddCachedTrace(new_cached_trace); @@ -675,10 +673,11 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, Array EvolutionaryNode::PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, - const SearchSpace& space, Sampler* sampler) { + const SearchSpace& space, + tir::TRandState* rand_state) { int num_rands = this->num_measures_per_iteration * this->eps_greedy; int num_bests = this->num_measures_per_iteration - num_rands; - std::vector rands = sampler->SampleWithoutReplacement(inits.size(), inits.size()); + std::vector rands = 
tir::SampleWithoutReplacement(rand_state, inits.size(), inits.size()); Array results; results.reserve(this->num_measures_per_iteration); for (int i = 0, i_bests = 0, i_rands = 0; i < this->num_measures_per_iteration; ++i) { @@ -780,11 +779,13 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return self->SampleInitPopulation(support, task, space, &seeded); + return self->SampleInitPopulation(support, task, space, &rand_state); } /*! * \brief Perform evolutionary search using genetic algorithm with the cost model @@ -798,11 +799,13 @@ struct Internal { */ static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return self->EvolveWithCostModel(inits, task, space, &seeded); + return self->EvolveWithCostModel(inits, task, space, &rand_state); } /*! 
@@ -816,11 +819,13 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return self->PickWithEpsGreedy(inits, bests, task, space, &seeded); + return self->PickWithEpsGreedy(inits, bests, task, space, &rand_state); } /*! diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index c425c8e5c8..3ed1b90e2d 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -35,8 +35,9 @@ Mutator::Mutator(String name, FApply apply) { /********** Mutator **********/ -Optional MutatorNode::Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { - return apply_(task, trace, sampler); +Optional MutatorNode::Apply(const SearchTask& task, const Trace& trace, + tir::TRandState* rand_state) { + return apply_(task, trace, rand_state); } /********** MutateTileSize **********/ @@ -77,17 +78,17 @@ class MutatorTileSize { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { // Find instruction `SamplePerfectTile` whose extent > 1 and n_splits > 1 std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Instruction& inst = candidates[sampler->SampleInt(0, candidates.size())]; + const Instruction& inst = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; std::vector tiles = CastDecision(trace->decisions.at(inst)); int n_splits = tiles.size(); // Choose two loops - int x = sampler->SampleInt(0, n_splits); + int x = 
tir::SampleInt(rand_state, 0, n_splits); int y; if (tiles[x] == 1) { // need to guarantee that tiles[x] * tiles[y] > 1 @@ -98,10 +99,10 @@ class MutatorTileSize { idx.push_back(i); } } - y = idx[sampler->SampleInt(0, idx.size())]; + y = idx[tir::SampleInt(rand_state, 0, idx.size())]; } else { // sample without replacement - y = sampler->SampleInt(0, n_splits - 1); + y = tir::SampleInt(rand_state, 0, n_splits - 1); if (y >= x) { ++y; } @@ -115,7 +116,7 @@ class MutatorTileSize { int len_x, len_y; if (y != n_splits - 1) { do { - std::vector result = sampler->SamplePerfectTile(2, tiles[x] * tiles[y]); + std::vector result = tir::SamplePerfectTile(rand_state, 2, tiles[x] * tiles[y]); len_x = result[0]; len_y = result[1]; } while (len_y == tiles[y]); @@ -132,7 +133,7 @@ class MutatorTileSize { if (len_y_space.empty()) { return NullOpt; } - len_y = len_y_space[sampler->SampleInt(0, len_y_space.size())]; + len_y = len_y_space[tir::SampleInt(rand_state, 0, len_y_space.size())]; len_x = prod / len_y; } tiles[x] = len_x; @@ -142,9 +143,9 @@ class MutatorTileSize { }; Mutator MutateTileSize() { - auto f_apply = [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorTileSize mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_tile_size", f_apply); } @@ -216,21 +217,21 @@ class MutatorComputeLocation { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { std::vector candidates = FindCandidates(trace, task->workload); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[sampler->SampleInt(0, candidates.size())]; - int loc = candidate.locs[sampler->SampleInt(0, candidate.locs.size())]; + const Candidate& 
candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); } }; Mutator MutateComputeLocation() { - auto f_apply = [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorComputeLocation mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_compute_location", f_apply); } @@ -308,13 +309,13 @@ class MutatorAutoUnroll { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[sampler->SampleInt(0, candidates.size())]; - int result = sampler->MakeMultinomial(candidate.weights)(); + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int result = tir::MakeMultinomial(rand_state, candidate.weights)(); if (result >= candidate.ori_decision) { result++; } @@ -323,9 +324,9 @@ class MutatorAutoUnroll { }; Mutator MutateAutoUnroll() { - auto f_apply = [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorAutoUnroll mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_unroll_depth", f_apply); } @@ -429,7 +430,8 @@ class MutatorParallel { return Candidate(Instruction{nullptr}, {}); } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) const { + Optional Apply(const SearchTask& 
task, const Trace& trace, + tir::TRandState* rand_state) const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); int max_extent = GetTargetNumCores(task->target, &warned_num_cores_missing) * max_jobs_per_core - 1; @@ -439,7 +441,7 @@ class MutatorParallel { } const BlockRV& block = Downcast(candidate.inst->inputs[0]); const std::vector& extent_candidates = candidate.extent_candidates; - int parallel_size = extent_candidates[sampler->SampleInt(0, extent_candidates.size())]; + int parallel_size = extent_candidates[tir::SampleInt(rand_state, 0, extent_candidates.size())]; std::vector new_insts; for (const Instruction& inst : trace->insts) { @@ -464,8 +466,8 @@ class MutatorParallel { Mutator MutateParallel(const int& max_jobs_per_core) { MutatorParallel mutator(max_jobs_per_core); - auto f_apply = [mutator](SearchTask task, Trace trace, void* sampler) -> Optional { - return mutator.Apply(task, trace, static_cast(sampler)); + auto f_apply = [mutator](SearchTask task, Trace trace, void* rand_state) -> Optional { + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_parallel", f_apply); } @@ -479,11 +481,13 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - Sampler seeded; - if (seed.defined()) { - seeded.Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && seed.value()->value != -1) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } - return mutator->Apply(task, trace, &seeded); + return mutator->Apply(task, trace, &rand_state); } }; diff --git a/src/meta_schedule/strategy/mutator.h b/src/meta_schedule/strategy/mutator.h index 79d1e8ccf8..286462c47c 100644 --- a/src/meta_schedule/strategy/mutator.h +++ b/src/meta_schedule/strategy/mutator.h @@ -44,10 +44,10 @@ class MutatorNode : public Object { * \brief Mutate the schedule by applying the 
mutation * \param task The search task * \param trace The trace to be mutated - * \param sampler The random number sampler + * \param rand_state The random state for sampling * \return The new schedule after mutation, NullOpt if mutation fails */ - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler); + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); diff --git a/src/meta_schedule/strategy/replay.cc b/src/meta_schedule/strategy/replay.cc index b6f24cf70d..0a4fb9a691 100644 --- a/src/meta_schedule/strategy/replay.cc +++ b/src/meta_schedule/strategy/replay.cc @@ -51,7 +51,7 @@ class ReplayNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) override; static constexpr const char* _type_key = "meta_schedule.Replay"; @@ -86,21 +86,21 @@ Replay::Replay(int batch_size, int num_trials) { /********** Search **********/ Optional ReplayNode::Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) { - std::vector thread_samplers; + std::vector thread_rand_states; std::vector thread_measure_inputs; - thread_samplers.reserve(this->batch_size); + thread_rand_states.reserve(this->batch_size); thread_measure_inputs.reserve(this->batch_size); for (int i = 0; i < batch_size; ++i) { - thread_samplers.emplace_back(sampler->ForkSeed()); + thread_rand_states.emplace_back(tir::ForkSeed(rand_state)); thread_measure_inputs.emplace_back(nullptr); } - auto worker = [&task, &space, &thread_samplers, &thread_measure_inputs](int thread_id, 
int i) { - Sampler* sampler = &thread_samplers[i]; + auto worker = [&task, &space, &thread_rand_states, &thread_measure_inputs](int thread_id, int i) { + tir::TRandState* rand_state = &thread_rand_states[i]; for (;;) { - Schedule sch = space->SampleSchedule(task, sampler); - if (space->Postprocess(task, sch, sampler)) { + Schedule sch = space->SampleSchedule(task, rand_state); + if (space->Postprocess(task, sch, rand_state)) { thread_measure_inputs[i] = MeasureInput(task, sch); break; } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 447f40cdc0..56aa63d4c2 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -23,12 +23,13 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int64_t seed, int debug_mode, +Schedule Schedule::Concrete(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; - n->sampler_.Seed(seed); + if (seed == -1) seed = std::random_device()(); + tir::RandEngine(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -180,12 +181,13 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb ScheduleCopier::Copy(this, new_state, new_symbol_table); } -Schedule ConcreteScheduleNode::Copy(int64_t new_seed) const { +Schedule ConcreteScheduleNode::Copy(tir::TRandState new_seed) const { ObjectPtr n = make_object(); Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); - n->sampler_.Seed(new_seed); + if (new_seed == -1) new_seed = std::random_device()(); + tir::RandEngine(&n->rand_state_).Seed(new_seed); return Schedule(std::move(n)); } @@ -218,7 +220,7 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, 
int int max_innermost_factor, Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SamplePerfectTile(state_, &this->sampler_, this->GetSRef(loop_rv), n, + return CreateRV(tir::SamplePerfectTile(state_, &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); } @@ -227,7 +229,7 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SampleCategorical(state_, &this->sampler_, candidates, probs, &decision)); + return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); } @@ -235,7 +237,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV( - tir::SampleComputeLocation(state_, &this->sampler_, this->GetSRef(block_rv), &decision)); + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); } @@ -666,7 +668,7 @@ void ConcreteScheduleNode::SoftwarePipeline(const LoopRV& loop_rv, int num_stage TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, int64_t seed, int debug_mode, + .set_body_typed([](IRModule mod, tir::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Concrete(mod, seed, debug_mode, static_cast(error_render_level)); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 429b372004..e274b633b1 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -23,7 +23,7 @@ #include #include -#include "./sampler.h" +#include "./primitive.h" #include 
"./utils.h" namespace tvm { @@ -42,7 +42,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief Source of randomness */ - Sampler sampler_; + tir::TRandState rand_state_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ @@ -53,7 +53,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `error_render_level_` is not visited // `state_` is not visited // `error_render_level_` is not visited - // `sampler_` is not visited + // `rand_state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied } @@ -66,9 +66,12 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } - Schedule Copy(int64_t new_seed = -1) const override; - void Seed(int64_t new_seed = -1) final { this->sampler_.Seed(new_seed); } - int64_t ForkSeed() final { return this->sampler_.ForkSeed(); } + Schedule Copy(tir::TRandState new_seed = -1) const override; + void Seed(tir::TRandState new_seed = -1) final { + if (new_seed == -1) new_seed = std::random_device()(); + RandEngine(&this->rand_state_).Seed(new_seed); + } + tir::TRandState ForkSeed() final { return tir::ForkSeed(&this->rand_state_); } public: /******** Lookup random variables ********/ @@ -83,7 +86,7 @@ class ConcreteScheduleNode : public ScheduleNode { void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } using ScheduleNode::GetSRef; - + public: /******** Schedule: Sampling ********/ Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index c30e53f9ff..fff72d5859 100644 --- 
a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -19,26 +19,119 @@ #ifndef TVM_TIR_SCHEDULE_PRIMITIVES_PRIMITIVES_H_ #define TVM_TIR_SCHEDULE_PRIMITIVES_PRIMITIVES_H_ +#include +#include #include +#include #include namespace tvm { namespace tir { - -class Sampler; - /******** Schedule: Sampling ********/ -TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler, +/*! \brief Return a seed that can be used as a new random state. */ +TRandState ForkSeed(TRandState* rand_state); +/*! + * \brief Sample an integer in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The integer sampled + */ +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive); +/*! + * \brief Sample n integers in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The list of integers sampled + */ +std::vector SampleInts(TRandState* rand_state, int n, int min_inclusive, int max_exclusive); +/*! + * \brief Random shuffle from the begin iterator to the end. + * \param begin_it The begin iterator + * \param end_it The end iterator + */ +template +void SampleShuffle(TRandState* rand_state, RandomAccessIterator begin_it, + RandomAccessIterator end_it); +/*! + * \brief Sample n tiling factors of the specific extent + * \param n The number of parts the loop is split + * \param extent Length of the loop + * \param candidates The possible tiling factors + * \return A list of length n, the tiling factors sampled + */ +std::vector SampleTileFactor(TRandState* rand_state, int n, int extent, + const std::vector& candidates); +/*! 
+ * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent); +/*! + * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \param max_innermost_factor A small number indicating the max length of the innermost loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent, + int max_innermost_factor); +/*! + * \brief Sample shape-generic tiling factors that are determined by the hardware constraints. + * \param n_splits The number of parts the loops are split + * \param max_extents Maximum length of the loops + * \param is_spatial Whether each loop is a spatial axis or not + * \param target Hardware target + * \param max_innermost_factor A small number indicating the max length of the innermost loop + * \return A list of list of length n_splits, the tiling factors sampled, all satisfying the + * maximum extents and the hardware constraints + */ +std::vector> SampleShapeGenericTiles(TRandState* rand_state, + const std::vector& n_splits, + const std::vector& max_extents, + const Target& target, + int max_innermost_factor); +/*! + * \brief Sample n floats uniformly in [min, max) + * \param min The left boundary + * \param max The right boundary + * \return The list of floats sampled + */ +std::vector SampleUniform(TRandState* rand_state, int n, double min, double max); +/*! 
+ * \brief Sample from a Bernoulli distribution + * \param p Parameter in the Bernoulli distribution + * \return return true with probability p, and false with probability (1 - p) + */ +bool SampleBernoulli(TRandState* rand_state, double p); +/*! + * \brief Create a multinomial sampler based on the specific weights + * \param weights The weights, event probabilities + * \return The multinomial sampler + */ +std::function MakeMultinomial(TRandState* rand_state, const std::vector& weights); +/*! + * \brief Classic sampling without replacement + * \param n The population size + * \param k The number of samples to be drawn from the population + * \return A list of indices, samples drawn, unsorted and index starting from 0 + */ +std::vector SampleWithoutReplacement(TRandState* rand_state, int n, int k); + +TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, tir::TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision); -TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, +TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, tir::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); -TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, +TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, tir::TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 66f3f8179d..02a16008c9 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -16,13 +16,559 @@ * specific language governing permissions and limitations * under the License. */ -#include "../sampler.h" +#include +#include + #include "../utils.h" namespace tvm { namespace tir { -std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler, +struct PrimeTable { + /*! 
\brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! \brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int factor = primes[i]; + int y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! 
+ * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int n) const { + std::vector> result; + result.reserve(16); + int i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int j; n > 1;) { + int i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + +TRandState ForkSeed(TRandState* rand_state) { + // In order for reproducibility, we compute the new seed using RNG's random state and a + // different set of parameters. Note that 1999999973 is a prime number (32767 = 2^15 - 1 is not).
+ TRandState ret = (RandEngine(rand_state)() * 32767) % 1999999973; + return ret; +} + +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { + RandEngine rand_(rand_state); + + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + +std::vector SampleInts(TRandState* rand_state, int n, int min_inclusive, int max_exclusive) { + RandEngine rand_(rand_state); + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; +} + +template +void SampleShuffle(TRandState* rand_state, RandomAccessIterator begin_it, + RandomAccessIterator end_it) { + RandEngine rand_(rand_state); + std::shuffle(begin_it, end_it, rand_); +} + +std::vector SampleUniform(TRandState* rand_state, int n, double min, double max) { + RandEngine rand_(rand_state); + std::uniform_real_distribution dist(min, max); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; +} + +bool SampleBernoulli(TRandState* rand_state, double p) { + RandEngine rand_(rand_state); + std::bernoulli_distribution dist(p); + return dist(rand_); +} + +std::function MakeMultinomial(TRandState* rand_state, const std::vector& weights) { + RandEngine rand_(rand_state); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { + RandEngine rand_(rand_state); + double p = dist(rand_); + int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? 
(n - 1) : idx; + }; + return sampler; +} + +std::vector SampleWithoutReplacement(TRandState* rand_state, int n, int k) { + if (k == 1) { + return {SampleInt(rand_state, 0, n)}; + } + if (k == 2) { + int result0 = SampleInt(rand_state, 0, n); + int result1 = SampleInt(rand_state, 0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int i = 0; i < n; ++i) { + order[i] = i; + } + for (int i = 0; i < k; ++i) { + int j = SampleInt(rand_state, i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; +} + +std::vector SampleTileFactor(TRandState* rand_state, int n, int extent, + const std::vector& candidates) { + RandEngine rand_(rand_state); + constexpr int kMaxTrials = 100; + std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); + std::vector sample(n, -1); + for (int trial = 0; trial < kMaxTrials; ++trial) { + int64_t product = 1; + for (int i = 1; i < n; ++i) { + int value = candidates[dist(rand_)]; + product *= value; + if (product > extent) { + break; + } + sample[i] = value; + } + if (product <= extent) { + sample[0] = (extent + product - 1) / product; + return sample; + } + } + sample[0] = extent; + for (int i = 1; i < n; ++i) { + sample[i] = 1; + } + return sample; +} + +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // (primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... 
(a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int result0 = 1; + int result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + int x1 = SampleInt(rand_state, 0, p + 1); + int x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(rand_state, 0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + if (p == 1) { + result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. We have to sample uniformly from the solution of: + // x_1 + x_2 + ... 
+ x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int i = 0, last = -1; i < n_splits; ++i) { + int x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; +} + +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent, + int max_innermost_factor) { + if (max_innermost_factor == -1) { + return SamplePerfectTile(rand_state, n_splits, extent); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. 
On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, n_splits - 1, extent / innermost); + result.push_back(innermost); + return result; +} + +static inline int ExtractInt(const Target& target, const char* name) { + if (Optional v = target->GetAttr(name)) { + return v.value()->value; + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +static inline bool IsCudaTarget(const Target& target) { + if (Optional v = target->GetAttr("kind")) { + return v.value() == "cuda"; + } + return false; +} + +std::vector> SampleShapeGenericTiles(TRandState* rand_state, + const std::vector& n_splits, + const std::vector& max_extents, + const Target& target, + int max_innermost_factor) { + std::vector> ret_split_factors; + + if (IsCudaTarget(target)) { + // The following factorization scheme is built under the assumption that: (1) The target is + // CUDA, and (2) The tiling structure is SSSRRSRS. 
+ + // extract all the hardware parameters + const struct HardwareConstraints { + int max_shared_memory_per_block; + int max_local_memory_per_block; + int max_threads_per_block; + int max_innermost_factor; + int max_vthread; + } constraints = {ExtractInt(target, "shared_memory_per_block"), + ExtractInt(target, "registers_per_block"), + ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; + + for (const int n_split : n_splits) { + ret_split_factors.push_back(std::vector(n_split, 1)); + } + + // sample the number of threads per block + const int warp_size = ExtractInt(target, "warp_size"); + int num_threads_per_block = + SampleInt(rand_state, 1, constraints.max_threads_per_block / warp_size) * warp_size; + // find all the possible factors of the number of threads per block + int num_spatial_axes = 0; + size_t last_spatial_iter_id = -1; + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + CHECK(n_splits[iter_id] == 4 || n_splits[iter_id] == 2) + << "The tiling structure is not SSSRRSRS"; + if (n_splits[iter_id] == 4) { + ++num_spatial_axes; + last_spatial_iter_id = iter_id; + } + } + + bool all_below_max_extents; + std::vector num_threads_factor_scheme; + do { + all_below_max_extents = true; + + num_threads_factor_scheme = + SamplePerfectTile(rand_state, num_spatial_axes, num_threads_per_block); + for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + if (num_threads_factor_scheme[spatial_iter_id] > max_extents[iter_id]) { + all_below_max_extents = false; + } + ++spatial_iter_id; + } + } // for (iter_id ∈ [0, split_steps_info.size())) + } while (!all_below_max_extents); + + // do the looping again and assign the factors + for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + ret_split_factors[iter_id][1] = num_threads_factor_scheme[spatial_iter_id]; + ++spatial_iter_id; + } + } + + // factor[0] (vthread) + 
int reg_usage = num_threads_per_block, shmem_usage = 0; + + auto sample_factors = [&](std::function continue_predicate, + std::function max_extent, + std::function factor_to_assign) { + std::vector iter_max_extents; + std::vector factors_to_assign; + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + size_t iter_max_extent = max_extent(iter_id), factor_to_assign; + + std::uniform_int_distribution<> dist(1, iter_max_extent); + factor_to_assign = SampleInt(rand_state, 1, iter_max_extent); + + if (n_splits[iter_id] == 4) { + reg_usage *= factor_to_assign; + } else { + shmem_usage *= factor_to_assign; + } + iter_max_extents.push_back(iter_max_extent); + factors_to_assign.push_back(factor_to_assign); + } + // shuffle the factors + std::vector factors_to_assign_bak = factors_to_assign; + SampleShuffle(rand_state, factors_to_assign.begin(), factors_to_assign.end()); + // make sure that the shuffle is valid + bool valid_shuffle = true; + std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), + factors_to_assign_it = factors_to_assign.begin(); + + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + int iter_max_extent = *iter_max_extents_it; + if (*factors_to_assign_it > iter_max_extent) { + valid_shuffle = false; + } + ++iter_max_extents_it; + ++factors_to_assign_it; + } + if (!valid_shuffle) { + factors_to_assign = std::move(factors_to_assign_bak); + } + // do the actual assignment + factors_to_assign_it = factors_to_assign.begin(); + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + factor_to_assign(iter_id) = *factors_to_assign_it; + ++factors_to_assign_it; + } + }; + + sample_factors( + [&](const size_t iter_id) -> bool { + return (n_splits[iter_id] != 4) || (iter_id != last_spatial_iter_id); + }, + [&](const size_t iter_id) -> int { + size_t 
max_vthread_extent = std::min( + constraints.max_vthread, max_extents[iter_id] / ret_split_factors[iter_id][1]); + max_vthread_extent = + std::min(constraints.max_vthread, constraints.max_local_memory_per_block / reg_usage); + return max_vthread_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); + + // factor[3] (innermost) + sample_factors( + [&](const size_t iter_id) -> bool { + return (n_splits[iter_id] != 4) || (iter_id == last_spatial_iter_id); + }, + [&](const size_t iter_id) -> int { + int max_innermost_extent = + std::min(max_innermost_factor, max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1]); + max_innermost_extent = + std::min(max_innermost_extent, constraints.max_local_memory_per_block / reg_usage); + return max_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][3]; }); + // factor[2] + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 4); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], + constraints.max_local_memory_per_block / reg_usage); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][2]; }); + + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] * + ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; + } + } + if (shmem_usage > static_cast(constraints.max_shared_memory_per_block / sizeof(float))) { + LOG(FATAL) << "shmem_usage goes out-of-range"; + } + // repeat similar procedure for reduction axes + // rfactor[1] (innermost) + sample_factors( + [&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, + [&](const size_t iter_id) 
-> int { + int max_innermost_extent = std::min(max_innermost_factor, max_extents[iter_id]); + max_innermost_extent = std::min(max_innermost_extent, + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); + return max_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][1]; }); + // rfactor[0] + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); + } // if (IsCudaTarget(target)) + return ret_split_factors; +} + +std::vector SamplePerfectTile(tir::ScheduleState self, TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision) { @@ -52,7 +598,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler result[0] = len; } else { // Case 3. 
Use fresh new sampling result - std::vector sampled = sampler->SamplePerfectTile(n, extent, max_innermost_factor); + std::vector sampled = SamplePerfectTile(rand_state, n, extent, max_innermost_factor); result = std::vector(sampled.begin(), sampled.end()); ICHECK_LE(sampled.back(), max_innermost_factor); } @@ -60,7 +606,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler return result; } -int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, +int64_t SampleCategorical(tir::ScheduleState self, TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { int i = -1; @@ -71,14 +617,14 @@ int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - i = sampler->MakeMultinomial(AsVector(probs))(); + i = MakeMultinomial(rand_state, AsVector(probs))(); ICHECK(0 <= i && i < n); } *decision = Integer(i); return candidates[i]; } -tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision) { // Find all possible compute-at locations Array loop_srefs = tir::CollectComputeLocation(self, block_sref); @@ -111,7 +657,7 @@ tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, } } else { // Sample possible combinations - i = sampler->SampleInt(-2, choices.size()); + i = SampleInt(rand_state, -2, choices.size()); if (i >= 0) { i = choices[i]; } diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc deleted file mode 100644 index 52d62b88f5..0000000000 --- a/src/tir/schedule/sampler.cc +++ /dev/null @@ -1,592 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "./sampler.h" - -#include -#include - -#include - -namespace tvm { -namespace tir { - -struct PrimeTable { - /*! \brief The table contains prime numbers in [2, kMaxPrime) */ - static constexpr const int kMaxPrime = 65536; - /*! \brief The exact number of prime numbers in the table */ - static constexpr const int kNumPrimes = 6542; - /*! - * \brief For each number in [2, kMaxPrime), the index of its min factor. - * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. - */ - int min_factor_idx[kMaxPrime]; - /*! \brief The prime numbers in [2, kMaxPrime) */ - std::vector primes; - /*! - * \brief The power of each prime number. - * pow_table[i, j] stores the result of pow(prime[i], j + 1) - */ - std::vector> pow_tab; - - /*! \brief Get a global instance of the prime table */ - static const PrimeTable* Global() { - static const PrimeTable table; - return &table; - } - - /*! 
\brief Constructor, pre-computes all info in the prime table */ - PrimeTable() { - constexpr const int64_t int_max = std::numeric_limits::max(); - // Euler's sieve: prime number in linear time - for (int i = 0; i < kMaxPrime; ++i) { - min_factor_idx[i] = -1; - } - primes.reserve(kNumPrimes); - for (int x = 2; x < kMaxPrime; ++x) { - if (min_factor_idx[x] == -1) { - min_factor_idx[x] = primes.size(); - primes.push_back(x); - } - for (size_t i = 0; i < primes.size(); ++i) { - int factor = primes[i]; - int y = x * factor; - if (y >= kMaxPrime) { - break; - } - min_factor_idx[y] = i; - if (x % factor == 0) { - break; - } - } - } - ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); - // Calculate the power table for each prime number - pow_tab.reserve(primes.size()); - for (int prime : primes) { - std::vector tab; - tab.reserve(32); - for (int64_t pow = prime; pow <= int_max; pow *= prime) { - tab.push_back(pow); - } - tab.shrink_to_fit(); - pow_tab.emplace_back(std::move(tab)); - } - } - /*! - * \brief Factorize a number n, and return in a cryptic format - * \param n The number to be factorized - * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] - * For each pair (i, j), we define - * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) - * (primes[i], j) if i != -1 - * Then the factorization is - * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... 
(a_l ^ b_l) - */ - std::vector> Factorize(int n) const { - std::vector> result; - result.reserve(16); - int i = 0, n_primes = primes.size(); - // Phase 1: n >= kMaxPrime - for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - if (j != 0) { - result.emplace_back(i, j); - } - } - // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number - if (n >= kMaxPrime) { - result.emplace_back(-1, n); - return result; - } - // Phase 2: n < kMaxPrime - for (int j; n > 1;) { - int i = min_factor_idx[n]; - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - result.emplace_back(i, j); - } - return result; - } -}; - -int Sampler::ForkSeed() { - uint32_t a = this->rand_(); - uint32_t b = this->rand_(); - uint32_t c = this->rand_(); - uint32_t d = this->rand_(); - return (a ^ b) * (c ^ d) % 1145141; -} - -void Sampler::Seed(int seed) { this->rand_.seed(seed); } - -int Sampler::SampleInt(int min_inclusive, int max_exclusive) { - if (min_inclusive + 1 == max_exclusive) { - return min_inclusive; - } - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - return dist(rand_); -} - -std::vector Sampler::SampleInts(int n, int min_inclusive, int max_exclusive) { - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; -} - -std::vector Sampler::SampleUniform(int n, double min, double max) { - std::uniform_real_distribution dist(min, max); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; -} - -bool Sampler::SampleBernoulli(double p) { - std::bernoulli_distribution dist(p); - return dist(rand_); -} - -std::function Sampler::MakeMultinomial(const std::vector& weights) { - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - 
for (double w : weights) { - sums.push_back(sum += w); - } - std::uniform_real_distribution dist(0.0, sum); - auto sampler = [this, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { - double p = dist(rand_); - int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? (n - 1) : idx; - }; - return sampler; -} - -std::vector Sampler::SampleWithoutReplacement(int n, int k) { - if (k == 1) { - return {SampleInt(0, n)}; - } - if (k == 2) { - int result0 = SampleInt(0, n); - int result1 = SampleInt(0, n - 1); - if (result1 >= result0) { - result1 += 1; - } - return {result0, result1}; - } - std::vector order(n); - for (int i = 0; i < n; ++i) { - order[i] = i; - } - for (int i = 0; i < k; ++i) { - int j = SampleInt(i, n); - if (i != j) { - std::swap(order[i], order[j]); - } - } - return {order.begin(), order.begin() + k}; -} - -std::vector Sampler::SampleTileFactor(int n, int extent, const std::vector& candidates) { - constexpr int kMaxTrials = 100; - std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); - std::vector sample(n, -1); - for (int trial = 0; trial < kMaxTrials; ++trial) { - int64_t product = 1; - for (int i = 1; i < n; ++i) { - int value = candidates[dist(rand_)]; - product *= value; - if (product > extent) { - break; - } - sample[i] = value; - } - if (product <= extent) { - sample[0] = (extent + product - 1) / product; - return sample; - } - } - sample[0] = extent; - for (int i = 1; i < n; ++i) { - sample[i] = 1; - } - return sample; -} - -std::vector Sampler::SamplePerfectTile(int n_splits, int extent) { - CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; - CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; - // Handle special case that we can potentially accelerate - if (n_splits == 1) { - return {extent}; - } - if (extent == 1) { - return 
std::vector(n_splits, 1); - } - // Enumerate each pair (i, j), we define - // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) - // (primes[i], j) if i != -1 - // Then the factorization is - // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) - const PrimeTable* prime_tab = PrimeTable::Global(); - std::vector> factorized = prime_tab->Factorize(extent); - if (n_splits == 2) { - // n_splits = 2, this can be taken special care of, - // because general reservoir sampling can be avoided to accelerate the sampling - int result0 = 1; - int result1 = 1; - for (const std::pair& ij : factorized) { - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - (SampleInt(0, 2) ? result1 : result0) *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - int x1 = SampleInt(0, p + 1); - int x2 = p - x1; - if (x1 != 0) { - result0 *= pow[x1]; - } - if (x2 != 0) { - result1 *= pow[x2]; - } - } - return {result0, result1}; - } - // Data range: - // 2 <= extent <= 2^31 - 1 - // 3 <= n_splits <= max tiling splits - // 1 <= p <= 31 - std::vector result(n_splits, 1); - for (const std::pair& ij : factorized) { - // Handle special cases to accelerate sampling - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - result[SampleInt(0, n_splits)] *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - if (p == 1) { - result[SampleInt(0, n_splits)] *= prime_tab->primes[ij.first]; - continue; - } - // The general case. We have to sample uniformly from the solution of: - // x_1 + x_2 + ... 
+ x_{n_splits} = p - // where x_i >= 0 - // Data range: - // 2 <= p <= 31 - // 3 <= n_splits <= max tiling splits - std::vector sampled = SampleWithoutReplacement(p + n_splits - 1, n_splits - 1); - std::sort(sampled.begin(), sampled.end()); - sampled.push_back(p + n_splits - 1); - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - for (int i = 0, last = -1; i < n_splits; ++i) { - int x = sampled[i] - last - 1; - last = sampled[i]; - if (x != 0) { - result[i] *= pow[x]; - } - } - } - return result; -} - -std::vector Sampler::SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) { - if (max_innermost_factor == -1) { - return this->SamplePerfectTile(n_splits, extent); - } - CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); - } - } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. 
On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add - // more heuristics in the future - int innermost = innermost_candidates[SampleInt(0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(n_splits - 1, extent / innermost); - result.push_back(innermost); - return result; -} - -static inline int ExtractInt(const Target& target, const char* name) { - if (Optional v = target->GetAttr(name)) { - return v.value()->value; - } - LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; - throw; -} - -static inline bool IsCudaTarget(const Target& target) { - if (Optional v = target->GetAttr("kind")) { - return v.value() == "cuda"; - } - return false; -} - -std::vector> Sampler::SampleShapeGenericTiles( - const std::vector& n_splits, const std::vector& max_extents, - const Target& target, int max_innermost_factor) { - std::vector> ret_split_factors; - - if (IsCudaTarget(target)) { - // The following factorization scheme is built under the assumption that: (1) The target is - // CUDA, and (2) The tiling structure is SSSRRSRS. 
- - // extract all the hardware parameters - const struct HardwareConstraints { - int max_shared_memory_per_block; - int max_local_memory_per_block; - int max_threads_per_block; - int max_innermost_factor; - int max_vthread; - } constraints = { - ExtractInt(target, "shared_memory_per_block"), ExtractInt(target, "registers_per_block"), - ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; - - for (const int n_split : n_splits) { - ret_split_factors.push_back(std::vector(n_split, 1)); - } - - // sample the number of threads per block - const int warp_size = ExtractInt(target, "warp_size"); - int num_threads_per_block = - SampleInt(1, constraints.max_threads_per_block / warp_size) * warp_size; - // find all the possible factors of the number of threads per block - int num_spatial_axes = 0; - size_t last_spatial_iter_id = -1; - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - CHECK(n_splits[iter_id] == 4 || n_splits[iter_id] == 2) - << "The tiling structure is not SSSRRSRS"; - if (n_splits[iter_id] == 4) { - ++num_spatial_axes; - last_spatial_iter_id = iter_id; - } - } - - bool all_below_max_extents; - std::vector num_threads_factor_scheme; - do { - all_below_max_extents = true; - - num_threads_factor_scheme = - SamplePerfectTile(num_spatial_axes, num_threads_per_block); - for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - if (num_threads_factor_scheme[spatial_iter_id] > max_extents[iter_id]) { - all_below_max_extents = false; - } - ++spatial_iter_id; - } - } // for (iter_id ∈ [0, split_steps_info.size())) - } while (!all_below_max_extents); - - // do the looping again and assign the factors - for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - ret_split_factors[iter_id][1] = num_threads_factor_scheme[spatial_iter_id]; - ++spatial_iter_id; - } - } - - // factor[0] (vthread) - int reg_usage = 
num_threads_per_block, shmem_usage = 0; - - auto sample_factors = [&](std::function continue_predicate, - std::function max_extent, - std::function factor_to_assign) { - std::vector iter_max_extents; - std::vector factors_to_assign; - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - size_t iter_max_extent = max_extent(iter_id), factor_to_assign; - - std::uniform_int_distribution<> dist(1, iter_max_extent); - factor_to_assign = SampleInt(1, iter_max_extent); - - if (n_splits[iter_id] == 4) { - reg_usage *= factor_to_assign; - } else { - shmem_usage *= factor_to_assign; - } - iter_max_extents.push_back(iter_max_extent); - factors_to_assign.push_back(factor_to_assign); - } - // shuffle the factors - std::vector factors_to_assign_bak = factors_to_assign; - Shuffle(factors_to_assign.begin(), factors_to_assign.end()); - // make sure that the shuffle is valid - bool valid_shuffle = true; - std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), - factors_to_assign_it = factors_to_assign.begin(); - - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - int iter_max_extent = *iter_max_extents_it; - if (*factors_to_assign_it > iter_max_extent) { - valid_shuffle = false; - } - ++iter_max_extents_it; - ++factors_to_assign_it; - } - if (!valid_shuffle) { - factors_to_assign = std::move(factors_to_assign_bak); - } - // do the actual assignment - factors_to_assign_it = factors_to_assign.begin(); - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - factor_to_assign(iter_id) = *factors_to_assign_it; - ++factors_to_assign_it; - } - }; - - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || - (iter_id != last_spatial_iter_id); - }, - [&](const size_t iter_id) -> int { - size_t max_vthread_extent = - 
std::min(constraints.max_vthread, - max_extents[iter_id] / ret_split_factors[iter_id][1]); - max_vthread_extent = - std::min(constraints.max_vthread, - constraints.max_local_memory_per_block / reg_usage); - return max_vthread_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][0]; - } - ); - - // factor[3] (innermost) - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || - (iter_id == last_spatial_iter_id); - }, - [&](const size_t iter_id) -> int { - int max_innermost_extent = - std::min(max_innermost_factor, - max_extents[iter_id] / ret_split_factors[iter_id][0] - / ret_split_factors[iter_id][1]); - max_innermost_extent = - std::min(max_innermost_extent, - constraints.max_local_memory_per_block / reg_usage); - return max_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][3]; - } - ); - // factor[2] - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4); - }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] - / ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], - constraints.max_local_memory_per_block / reg_usage - ); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][2]; - } - ); - - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] - * ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; - } - } - if (shmem_usage > static_cast(constraints.max_shared_memory_per_block / sizeof(float))) { - LOG(FATAL) << "shmem_usage goes out-of-range"; - } - // repeat similar procedure for reduction axes - // rfactor[1] (innermost) - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 2); - }, - 
[&](const size_t iter_id) -> int { - int max_innermost_extent = - std::min(max_innermost_factor, max_extents[iter_id]); - max_innermost_extent = - std::min(max_innermost_extent, - static_cast( - constraints.max_shared_memory_per_block / sizeof(float) / shmem_usage - )); - return max_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][1]; - } - ); - // rfactor[0] - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 2); - }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], - static_cast( - constraints.max_shared_memory_per_block / sizeof(float) / shmem_usage - )); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][0]; - } - ); - } // if (IsCudaTarget(target)) - return ret_split_factors; -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h deleted file mode 100644 index 5aa87f984b..0000000000 --- a/src/tir/schedule/sampler.h +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#ifndef TVM_TIR_SCHEDULE_SAMPLER_H_ -#define TVM_TIR_SCHEDULE_SAMPLER_H_ - -#include -#include -#include -#include - -namespace tvm { - -class Target; - -namespace tir { - -/*! \brief Random number sampler used for sampling in meta schedule */ -class Sampler { - public: - /*! \brief Return a seed that can be used to create a new sampler */ - int ForkSeed(); - /*! \brief Re-seed the random number generator */ - void Seed(int seed); - /*! - * \brief Sample an integer in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The integer sampled - */ - int SampleInt(int min_inclusive, int max_exclusive); - /*! - * \brief Sample n integers in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The list of integers sampled - */ - std::vector SampleInts(int n, int min_inclusive, int max_exclusive); - /*! - * \brief Random shuffle from the begin iterator to the end. - * \param begin_it The begin iterator - * \param end_it The end iterator - */ - template - void Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it); - /*! - * \brief Sample n tiling factors of the specific extent - * \param n The number of parts the loop is split - * \param extent Length of the loop - * \param candidates The possible tiling factors - * \return A list of length n, the tiling factors sampled - */ - std::vector SampleTileFactor(int n, int extent, const std::vector& candidates); - /*! - * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent); - /*! 
- * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \param max_innermost_factor A small number indicating the max length of the innermost loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent, int max_innermost_factor); - /*! - * \brief Sample shape-generic tiling factors that are determined by the hardware constraints. - * \param n_splits The number of parts the loops are split - * \param max_extents Maximum length of the loops - * \param is_spatial Whether each loop is a spatial axis or not - * \param target Hardware target - * \param max_innermost_factor A small number indicating the max length of the innermost loop - * \return A list of list of length n_splits, the tiling factors sampled, all satisfying the - * maximum extents and the hardware constraints - */ - std::vector> SampleShapeGenericTiles(const std::vector& n_splits, - const std::vector& max_extents, - const Target& target, - int max_innermost_factor); - /*! - * \brief Sample n floats uniformly in [min, max) - * \param min The left boundary - * \param max The right boundary - * \return The list of floats sampled - */ - std::vector SampleUniform(int n, double min, double max); - /*! - * \brief Sample from a Bernoulli distribution - * \param p Parameter in the Bernoulli distribution - * \return return true with probability p, and false with probability (1 - p) - */ - bool SampleBernoulli(double p); - /*! - * \brief Create a multinomial sampler based on the specific weights - * \param weights The weights, event probabilities - * \return The multinomial sampler - */ - std::function MakeMultinomial(const std::vector& weights); - /*! 
- * \brief Classic sampling without replacement - * \param n The population size - * \param k The number of samples to be drawn from the population - * \return A list of indices, samples drawn, unsorted and index starting from 0 - */ - std::vector SampleWithoutReplacement(int n, int k); - /*! - * \brief Constructor. Construct a sampler seeded with std::random_device - */ - Sampler() : Sampler(std::random_device /**/ {}()) {} - /*! - * \brief Constructor. Construct a sampler seeded with the specific integer - * \param seed The random seed - */ - explicit Sampler(int seed) : rand_(seed) {} - - private: - /*! \brief The random number generator */ - std::minstd_rand rand_; -}; - - -template -void Sampler::Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it) { - std::shuffle(begin_it, end_it, rand_); -} - - -} // namespace tir -} // namespace tvm - -#endif // TVM_TIR_SCHEDULE_SAMPLER_H_ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e2883ddb8f..57e3cd63b7 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,24 +21,26 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, int64_t seed, int debug_mode, +Schedule Schedule::Traced(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; - n->sampler_.Seed(seed); + if (seed == -1) seed = std::random_device()(); + tir::RandEngine(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); return Schedule(std::move(n)); } -Schedule TracedScheduleNode::Copy(int64_t new_seed) const { +Schedule TracedScheduleNode::Copy(tir::TRandState new_seed) const { ObjectPtr n = make_object(); ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; 
n->analyzer_ = std::make_unique(); - n->sampler_.Seed(new_seed); + if (new_seed == -1) new_seed = std::random_device()(); + tir::RandEngine(&n->rand_state_).Seed(new_seed); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); return Schedule(std::move(n)); } @@ -48,8 +50,9 @@ Schedule TracedScheduleNode::Copy(int64_t new_seed) const { Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision) { - Array results = CreateRV(tir::SamplePerfectTile( - this->state_, &this->sampler_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); + Array results = + CreateRV(tir::SamplePerfectTile(this->state_, &this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision)); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -63,8 +66,8 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - ExprRV result = - CreateRV(tir::SampleCategorical(this->state_, &this->sampler_, candidates, probs, &decision)); + ExprRV result = CreateRV( + tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -77,7 +80,7 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, Optional decision) { - LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->sampler_, + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); @@ -450,7 +453,7 @@ void 
TracedScheduleNode::InlineArgument(int i, const String& func_name) { TVM_REGISTER_NODE_TYPE(TracedScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, int64_t seed, int debug_mode, + .set_body_typed([](IRModule mod, tir::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Traced(mod, seed, debug_mode, static_cast(error_render_level)); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ff362efa99..a0e7c714c6 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -34,7 +34,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void VisitAttrs(tvm::AttrVisitor* v) { // `state_` is not visited // `error_render_level_` is not visited - // `sampler_` is not visited + // `rand_state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied // `trace_` is not visited @@ -47,7 +47,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: Optional trace() const final { return trace_; } - Schedule Copy(int64_t new_seed = -1) const final; + Schedule Copy(tir::TRandState new_seed = -1) const final; public: /******** Schedule: Sampling ********/ diff --git a/tests/cpp/meta_schedule_test.cc b/tests/cpp/meta_schedule_test.cc new file mode 100644 index 0000000000..a01cbfef14 --- /dev/null +++ b/tests/cpp/meta_schedule_test.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "../../../src/tir/schedule/primitive.h" + +TEST(Simplify, Sampling) { + int64_t current = 100; + for (int i = 0; i < 10; i++) { + tvm::tir::SampleInt(¤t, 0, 100); + tvm::tir::SampleUniform(¤t, 3, -1, 0); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc new file mode 100644 index 0000000000..10b8afa0ee --- /dev/null +++ b/tests/cpp/random_engine_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +TEST(RandomEngine, Randomness) { + int64_t rand_state = 0; + + tvm::support::LinearCongruentialEngine rng(&rand_state); + rng.Seed(0x114514); + + bool covered[100]; + memset(covered, 0, sizeof(covered)); + for (int i = 0; i < 100000; i++) { + covered[rng() % 100] = true; + } + for (int i = 0; i < 100; i++) { + ICHECK(covered[i]); + } +} + +TEST(RandomEngine, Reproducibility) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x23456789); + rng_b.Seed(0x23456789); + + for (int i = 0; i < 100000; i++) { + ICHECK_EQ(rng_a(), rng_b()); + } +} + +TEST(RandomEngine, Serialization) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x56728); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); + + for (int i = 0; i < 123456; i++) rng_a(); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py index 8866d0d6ad..55d644d4c8 100644 --- a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py +++ b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py @@ -143,7 +143,7 @@ def test_sparse_dense(): print("M =", M, "N =", N, "K =", K, "BS_R =", BS_R, "BS_C = ", BS_C) def check_device(device): - ctx = tvm.context(device, 0) + ctx = tvm.device(device, 0) if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return @@ -154,12 +154,12 @@ def check_device(device): Y = fcompute(X, W_data, 
W_indices, W_indptr) s = fschedule([Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device=ctx) func( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) @@ -168,10 +168,10 @@ def check_device(device): "sparse dense te schedule: %f ms" % ( evaluator( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ).mean * 1e3 @@ -187,23 +187,23 @@ def check_device(device): func = func.specialize(N_blocks, N // BS_R).remove_const_param(N_blocks) def f_create_args(ctx): - X = tvm.nd.array(X_np, ctx=ctx) - W_data = tvm.nd.array(W_sp_np.data, ctx=ctx) - W_indices = tvm.nd.array(W_sp_np.indices, ctx=ctx) - W_indptr = tvm.nd.array(W_sp_np.indptr, ctx=ctx) - Y = tvm.nd.array(Y_np, ctx=ctx) + X = tvm.nd.array(X_np, device=ctx) + W_data = tvm.nd.array(W_sp_np.data, device=ctx) + W_indices = tvm.nd.array(W_sp_np.indices, device=ctx) + W_indptr = tvm.nd.array(W_sp_np.indptr, device=ctx) + Y = tvm.nd.array(Y_np, device=ctx) return [X, W_data, W_indices, W_indptr, Y] sch = meta_schedule_sparse_dense_llvm(func, f_create_args) func = sch.mod func = tvm.build(func) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device=ctx) func( - tvm.nd.array(X_np, 
ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) @@ -212,10 +212,10 @@ def f_create_args(ctx): "sparse dense auto tir schedule: %f ms" % ( evaluator( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ).mean * 1e3 diff --git a/tests/python/meta_schedule/test_meta_schedule_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature.py index 57b0edefe8..49ead3062c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature.py @@ -675,7 +675,7 @@ def _check_compute(feature): 1, 0, 0, 29, 20, 23, 14, 1, 0, 0, 18, 20.005626, 4.0874629, 25, 16, 19, 10.0014086, 1, ], write_feature=[ - 0, 1, 0, 29, 12.000352, 23, 9.002815, 1, 0, 0, 10.001408, 13.000176, 8.005625, 21, 4.087463, 15, + 0, 1, 0, 29, 12.000352, 23, 9.002815, 1, 0, 0, 10.001408, 13.000176, 8.005625, 21, 4.087463, 15, #pylint: disable=line-too-long 1.584963, 1, ], # fmt: on diff --git a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py index 78061c722d..0b56143fab 100644 --- a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py +++ b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py @@ -200,9 +200,7 @@ def run_module(lib, use_arm): lib.export_library(tmp.relpath(filename)) # Upload module to device 
print("Upload...") - remote = auto_scheduler.utils.request_remote( - RPC_KEY, "172.16.2.241", 4445, timeout=10000 - ) + remote = auto_scheduler.utils.request_remote(RPC_KEY, "localhost", 4728, timeout=10000) remote.upload(tmp.relpath(filename)) rlib = remote.load_module(filename) diff --git a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py index 71a0f649e4..b8b94ddbc2 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py @@ -177,12 +177,14 @@ def test_end_to_end_resnet(log): measure_callbacks=[ ms.RecordToFile(), ] - ) + ), ) with ms.ApplyHistoryBest(log, SPACE): - with tvm.transform.PassContext(opt_level=3, config={"relay.with_tir_schedule": True, - "relay.backend.use_meta_schedule": True}): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.with_tir_schedule": True, "relay.backend.use_meta_schedule": True}, + ): lib = relay.build_module.build(mod, target, params=params) def run_module(lib): @@ -195,7 +197,10 @@ def run_module(lib): print("Evaluate inference time cost...") ftimer = module.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=500) prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) + print( + "Mean inference time (std dev): %.2f ms (%.2f ms)" + % (np.mean(prof_res), np.std(prof_res)) + ) module.run() return module.get_output(0)