From b39c297d025d9254fd37cc295880eaf7bd752eda Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Jul 2021 12:08:33 -0700 Subject: [PATCH 01/23] Add sampler & rng. --- src/meta_schedule/sampler.h | 424 ++++++++++++++++++++++++++++++++ src/support/rng.h | 113 +++++++++ tests/cpp/meta_schedule_test.cc | 37 +++ 3 files changed, 574 insertions(+) create mode 100644 src/meta_schedule/sampler.h create mode 100644 src/support/rng.h create mode 100644 tests/cpp/meta_schedule_test.cc diff --git a/src/meta_schedule/sampler.h b/src/meta_schedule/sampler.h new file mode 100644 index 0000000000..cec355075e --- /dev/null +++ b/src/meta_schedule/sampler.h @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SAMPLER_H_ +#define TVM_META_SCHEDULE_SAMPLER_H_ + +#include + +#include + +#include "../support/rng.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The struct contains a prime table and the function for factorization. */ +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! 
\brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int factor = primes[i]; + int y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! + * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int n) const { + std::vector> result; + result.reserve(16); + int i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int j; n > 1;) { + int i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + +/*! \brief Sampler based on random number generator for sampling functions used in meta schedule. + * Typical usage is like Sampler(&random_state).SamplingFunc(...). */ +class Sampler { + public: + /*! \brief Return a reproducible random state value that can be used as seed for new samplers. + * \return The random state value to be used as seed for new samplers. + */ + int64_t ForkSeed() { + // In order for reproducibility, we computer the new seed using sampler's RNG's current random + // state and a different set of multiplier & modulus. + // Note that 32767 & 1999999973 are prime numbers. + int64_t ret = (this->rand_.random_state() * 32767) % 1999999973; + this->rand_.next_state(); + return ret; + } + /*! \brief Re-seed the random number generator + * \param seed The value given used to re-seed the RNG. + */ + void Seed(support::RandomNumberGenerator::result_type seed) { this->rand_.seed(seed); } + /*! + * \brief Sample an integer in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The integer sampled + */ + int SampleInt(int min_inclusive, int max_exclusive) { + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + return dist(rand_); + } + /*! 
+ * \brief Sample n integers in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The list of integers sampled + */ + std::vector SampleInts(int n, int min_inclusive, int max_exclusive) { + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; + } + /*! + * \brief Sample n tiling factors of the specific extent + * \param n The number of parts the loop is split + * \param extent Length of the loop + * \param candidates The possible tiling factors + * \return A list of length n, the tiling factors sampled + */ + std::vector SampleTileFactor(int n, int extent, const std::vector& candidates) { + constexpr int kMaxTrials = 100; + std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); + std::vector sample(n, -1); + for (int trial = 0; trial < kMaxTrials; ++trial) { + int64_t product = 1; + for (int i = 1; i < n; ++i) { + int value = candidates[dist(rand_)]; + product *= value; + if (product > extent) { + break; + } + sample[i] = value; + } + if (product <= extent) { + sample[0] = (extent + product - 1) / product; + return sample; + } + } + sample[0] = extent; + for (int i = 1; i < n; ++i) { + sample[i] = 1; + } + return sample; + } + /*! + * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ + std::vector SamplePerfectTile(int n_splits, int extent) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // (primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int result0 = 1; + int result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(0, 2) ? 
result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + int x1 = SampleInt(0, p + 1); + int x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + if (p == 1) { + result[SampleInt(0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. We have to sample uniformly from the solution of: + // x_1 + x_2 + ... + x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = SampleWithoutReplacement(p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int i = 0, last = -1; i < n_splits; ++i) { + int x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; + } + /*! + * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \param max_innermost_factor A small number indicating the max length of the innermost loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ + std::vector SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) { + if (max_innermost_factor == -1) { + return this->SamplePerfectTile(n_splits, extent); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe + // add more heuristics in the future + int innermost = innermost_candidates[SampleInt(0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(n_splits - 1, extent / innermost); + result.push_back(innermost); + return result; + } + /*! + * \brief Sample n floats uniformly in [min, max) + * \param min The left boundary + * \param max The right boundary + * \return The list of floats sampled + */ + std::vector SampleUniform(int n, double min, double max) { + std::uniform_real_distribution dist(min, max); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; + } + /*! 
+ * \brief Sample from a Bernoulli distribution + * \param p Parameter in the Bernoulli distribution + * \return return true with probability p, and false with probability (1 - p) + */ + bool SampleBernoulli(double p) { + std::bernoulli_distribution dist(p); + return dist(rand_); + } + /*! + * \brief Create a multinomial sampler based on the specific weights + * \param weights The weights, event probabilities + * \return The multinomial sampler + */ + std::function MakeMultinomial(const std::vector& weights) { + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [this, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { + double p = dist(rand_); + int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; + return sampler; + } + /*! + * \brief Classic sampling without replacement + * \param n The population size + * \param k The number of samples to be drawn from the population + * \return A list of indices, samples drawn, unsorted and index starting from 0 + */ + std::vector SampleWithoutReplacement(int n, int k) { + if (k == 1) { + return {SampleInt(0, n)}; + } + if (k == 2) { + int result0 = SampleInt(0, n); + int result1 = SampleInt(0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int i = 0; i < n; ++i) { + order[i] = i; + } + for (int i = 0; i < k; ++i) { + int j = SampleInt(i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; + } + + /*! \brief The default constructor function for Sampler */ + Sampler() = default; + /*! + * \brief Constructor. Construct a sampler with a given random state pointer for its RNG. + * \param random_state The given pointer to random state used to construct the RNG. + * \note The random state is neither initialized not modified by this constructor. + */ + explicit Sampler(support::RandomNumberGenerator::result_type* random_state) + : rand_(random_state) {} + + private: + /*! \brief The random number generator for sampling. */ + support::RandomNumberGenerator rand_; +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SAMPLER_H_ \ No newline at end of file diff --git a/src/support/rng.h b/src/support/rng.h new file mode 100644 index 0000000000..7b4123d9d0 --- /dev/null +++ b/src/support/rng.h @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rng.h + * \brief Random number generator, for Sampler and Sampling + * functions. 
+ */ + +#ifndef TVM_SUPPORT_RNG_H_ +#define TVM_SUPPORT_RNG_H_ + +#include + +#include // for int64_t + +namespace tvm { +namespace support { + +/*! + * \brief The random number generator is implemented as a linear congruential engine. + */ +class RandomNumberGenerator { + public: + /*! \brief The result type is defined as int64_t here for sampler usage. */ + using result_type = int64_t; + + /*! \brief The multiplier */ + static constexpr result_type multiplier = 48271; + + /*! \brief The increment */ + static constexpr result_type increment = 0; + + /*! \brief The modulus */ + static constexpr result_type modulus = 2147483647; + + /*! \brief Construct a null random number generator. */ + RandomNumberGenerator() { rand_state_ptr = nullptr; } + + /*! + * \brief Construct a random number generator with a random state pointer. + * \param random_state The random state pointer given in result_type*. + * \note The random state is not initialized here. You may need to call seed function. + */ + explicit RandomNumberGenerator(result_type* random_state) { rand_state_ptr = random_state; } + + /*! + * \brief Change the start random state of RNG with the seed of a new random state value. + * \param random_state The random state given in result_type. + * \note The seed is used to initialize the random number generator and the random state would be + * changed to next random state by calling the next_state() function. + */ + void seed(result_type state = 1) { + state %= modulus; // Make sure the seed is within the range of the modulus. + if (state < 0) state += modulus; // The congruential engine is always non-negative. + ICHECK(rand_state_ptr != nullptr); // Make sure the pointer is not null. + *rand_state_ptr = state; // Change pointed random state to given random state value. + next_state(); + }; + + /*! \brief The minimum possible value of random state here. */ + result_type min() { return 0; } + + /*! \brief The maximum possible value of random state here. */ + result_type max() { return modulus - 1; } + + /*! + * \brief Fetch the current random state. + * \return The current random state value in the type of result_type. + */ + result_type random_state() { return *rand_state_ptr; } + + /*! + * \brief Operator to fetch the current random state. + * \return The current random state value in the type of result_type. + */ + result_type operator()() { return next_state(); } + + /*! + * \brief Move the random state to the next and return the new random state. According to + * definition of linear congruential engine, the new random state value is computed as + * new_random_state = (current_random_state * multiplier + increment) % modulus. + * \return The next current random state value in the type of result_type. + */ + result_type next_state() { + (*rand_state_ptr) = ((*rand_state_ptr) * multiplier + increment) % modulus; + return *rand_state_ptr; + } + + private: + result_type* rand_state_ptr; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_RNG_H_ diff --git a/tests/cpp/meta_schedule_test.cc b/tests/cpp/meta_schedule_test.cc new file mode 100644 index 0000000000..e65621a563 --- /dev/null +++ b/tests/cpp/meta_schedule_test.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "../../../src/meta_schedule/sampler.h" + +TEST(Simplify, Sampler) { + int64_t current = 100; + for (int i = 0; i < 10; i++) { + tvm::meta_schedule::Sampler(¤t).SampleInt(0, 100); + tvm::meta_schedule::Sampler(¤t).SampleUniform(3, -1, 0); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} \ No newline at end of file From 38aa8adb24f4f942d28aa7f24adb3b5577662870 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Jul 2021 17:17:34 -0700 Subject: [PATCH 02/23] Make Sampler work with current implementation. --- src/meta_schedule/autotune.cc | 4 +- src/meta_schedule/autotune.h | 3 +- .../cost_model/rand_cost_model.cc | 8 +- src/meta_schedule/sampler.h | 424 ------------------ src/meta_schedule/search.cc | 33 +- src/meta_schedule/search.h | 13 +- src/meta_schedule/space/post_order_apply.cc | 29 +- src/meta_schedule/space/postproc.cc | 31 +- src/meta_schedule/space/postproc.h | 4 +- src/meta_schedule/space/schedule_fn.cc | 23 +- src/meta_schedule/space/search_rule.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 129 +++--- src/meta_schedule/strategy/mutator.cc | 62 +-- src/meta_schedule/strategy/mutator.h | 5 +- src/meta_schedule/strategy/replay.cc | 20 +- src/support/rng.h | 1 + src/tir/schedule/concrete_schedule.cc | 10 +- src/tir/schedule/concrete_schedule.h | 14 +- src/tir/schedule/primitive.h | 9 +- src/tir/schedule/primitive/sampling.cc | 13 +- src/tir/schedule/sampler.cc | 239 +++++----- src/tir/schedule/sampler.h | 45 +- src/tir/schedule/traced_schedule.cc | 19 +- src/tir/schedule/traced_schedule.h | 2 +- tests/cpp/meta_schedule_test.cc | 6 +- 25 files changed, 362 insertions(+), 786 deletions(-) delete mode 100644 src/meta_schedule/sampler.h diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index a2ce7fb4a0..24d2a1bf39 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -25,7 +25,7 @@ namespace meta_schedule { void TuneContextNode::Init(Optional seed) { if (seed.defined()) { - this->sampler.Seed(seed.value()->value); + Sampler(&this->rand_state).Seed(seed.value()->value); } if (task.defined()) { task.value()->Init(this); @@ -59,7 +59,7 @@ void TuneContextNode::Init(Optional seed) { bool TuneContextNode::Postprocess(const Schedule& sch) { sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { - if (!postproc->Apply(task.value(), sch, &sampler)) { + if (!postproc->Apply(task.value(), sch, &rand_state)) { return false; } } diff --git a/src/meta_schedule/autotune.h b/src/meta_schedule/autotune.h index dc7c391f50..3196bb8f72 100644 --- a/src/meta_schedule/autotune.h +++ b/src/meta_schedule/autotune.h @@ -43,7 +43,8 @@ class TuneContextNode : public runtime::Object { Array postprocs; Array measure_callbacks; int num_threads; - Sampler sampler; + + Sampler::TRandomState rand_state; 
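Note: the recurring change in this commit is to stop storing a stateful Sampler object and instead keep only a plain Sampler::TRandomState value (presumably an alias for the RNG's int64_t result_type; the alias itself is not defined in the hunks shown here), wrapping it in a short-lived Sampler at each call site, e.g. Sampler(&rand_state).Seed(...). A minimal usage sketch, assuming the Sampler interface from patch 01 and that the code lives in namespace tvm::meta_schedule; the function name is hypothetical:

    // Illustrative sketch only.
    #include <vector>

    void ExampleDraws() {
      Sampler::TRandomState rand_state = 0;                    // plain integer state, cheap to store
      Sampler(&rand_state).Seed(42);                           // seed in place
      int pick = Sampler(&rand_state).SampleInt(0, 100);       // integer in [0, 100)
      std::vector<double> u = Sampler(&rand_state).SampleUniform(3, 0.0, 1.0);
      Sampler::TRandomState child = Sampler(&rand_state).ForkSeed();  // reproducible fork
      (void)pick; (void)u; (void)child;
    }

Keeping only the raw state also makes forking a seed per thread or per child sampler cheap, which the evolutionary search changes later in this patch rely on.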
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 697e9aefe2..92b78fa078 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -27,8 +27,8 @@ namespace meta_schedule { /*! \brief The cost model returning random value for all predictions */ class RandCostModelNode : public CostModelNode { public: - /*! \brief A sampler for generating random numbers */ - Sampler sampler; + /*! \brief A random state for sampler to generate random numbers */ + Sampler::TRandomState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { // sampler is not visited @@ -48,7 +48,7 @@ class RandCostModelNode : public CostModelNode { * \return The predicted scores for all states */ std::vector Predict(const SearchTask& task, const Array& states) override { - return sampler.SampleUniform(states.size(), 0.0, 1.0); + return Sampler(&rand_state).SampleUniform(states.size(), 0.0, 1.0); } static constexpr const char* _type_key = "meta_schedule.RandCostModel"; @@ -65,7 +65,7 @@ class RandCostModel : public CostModel { explicit RandCostModel(int seed) { ObjectPtr n = make_object(); - n->sampler.Seed(seed); + Sampler(&n->rand_state).Seed(seed); data_ = std::move(n); } diff --git a/src/meta_schedule/sampler.h b/src/meta_schedule/sampler.h deleted file mode 100644 index cec355075e..0000000000 --- a/src/meta_schedule/sampler.h +++ /dev/null @@ -1,424 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_META_SCHEDULE_SAMPLER_H_ -#define TVM_META_SCHEDULE_SAMPLER_H_ - -#include - -#include - -#include "../support/rng.h" - -namespace tvm { -namespace meta_schedule { - -/*! \brief The struct contains a prime table and the function for factorization. */ -struct PrimeTable { - /*! \brief The table contains prime numbers in [2, kMaxPrime) */ - static constexpr const int kMaxPrime = 65536; - /*! \brief The exact number of prime numbers in the table */ - static constexpr const int kNumPrimes = 6542; - /*! - * \brief For each number in [2, kMaxPrime), the index of its min factor. - * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. - */ - int min_factor_idx[kMaxPrime]; - /*! \brief The prime numbers in [2, kMaxPrime) */ - std::vector primes; - /*! - * \brief The power of each prime number. - * pow_table[i, j] stores the result of pow(prime[i], j + 1) - */ - std::vector> pow_tab; - - /*! \brief Get a global instance of the prime table */ - static const PrimeTable* Global() { - static const PrimeTable table; - return &table; - } - - /*! 
\brief Constructor, pre-computes all info in the prime table */ - PrimeTable() { - constexpr const int64_t int_max = std::numeric_limits::max(); - // Euler's sieve: prime number in linear time - for (int i = 0; i < kMaxPrime; ++i) { - min_factor_idx[i] = -1; - } - primes.reserve(kNumPrimes); - for (int x = 2; x < kMaxPrime; ++x) { - if (min_factor_idx[x] == -1) { - min_factor_idx[x] = primes.size(); - primes.push_back(x); - } - for (size_t i = 0; i < primes.size(); ++i) { - int factor = primes[i]; - int y = x * factor; - if (y >= kMaxPrime) { - break; - } - min_factor_idx[y] = i; - if (x % factor == 0) { - break; - } - } - } - ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); - // Calculate the power table for each prime number - pow_tab.reserve(primes.size()); - for (int prime : primes) { - std::vector tab; - tab.reserve(32); - for (int64_t pow = prime; pow <= int_max; pow *= prime) { - tab.push_back(pow); - } - tab.shrink_to_fit(); - pow_tab.emplace_back(std::move(tab)); - } - } - /*! - * \brief Factorize a number n, and return in a cryptic format - * \param n The number to be factorized - * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] - * For each pair (i, j), we define - * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) - * (primes[i], j) if i != -1 - * Then the factorization is - * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) - */ - std::vector> Factorize(int n) const { - std::vector> result; - result.reserve(16); - int i = 0, n_primes = primes.size(); - // Phase 1: n >= kMaxPrime - for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - if (j != 0) { - result.emplace_back(i, j); - } - } - // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number - if (n >= kMaxPrime) { - result.emplace_back(-1, n); - return result; - } - // Phase 2: n < kMaxPrime - for (int j; n > 1;) { - int i = min_factor_idx[n]; - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - result.emplace_back(i, j); - } - return result; - } -}; - -/*! \brief Sampler based on random number generator for sampling functions used in meta schedule. - * Typical usage is like Sampler(&random_state).SamplingFunc(...). */ -class Sampler { - public: - /*! \brief Return a reproducible random state value that can be used as seed for new samplers. - * \return The random state value to be used as seed for new samplers. - */ - int64_t ForkSeed() { - // In order for reproducibility, we computer the new seed using sampler's RNG's current random - // state and a different set of multiplier & modulus. - // Note that 32767 & 1999999973 are prime numbers. - int64_t ret = (this->rand_.random_state() * 32767) % 1999999973; - this->rand_.next_state(); - return ret; - } - /*! \brief Re-seed the random number generator - * \param seed The value given used to re-seed the RNG. - */ - void Seed(support::RandomNumberGenerator::result_type seed) { this->rand_.seed(seed); } - /*! - * \brief Sample an integer in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The integer sampled - */ - int SampleInt(int min_inclusive, int max_exclusive) { - if (min_inclusive + 1 == max_exclusive) { - return min_inclusive; - } - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - return dist(rand_); - } - /*! 
- * \brief Sample n integers in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The list of integers sampled - */ - std::vector SampleInts(int n, int min_inclusive, int max_exclusive) { - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; - } - /*! - * \brief Sample n tiling factors of the specific extent - * \param n The number of parts the loop is split - * \param extent Length of the loop - * \param candidates The possible tiling factors - * \return A list of length n, the tiling factors sampled - */ - std::vector SampleTileFactor(int n, int extent, const std::vector& candidates) { - constexpr int kMaxTrials = 100; - std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); - std::vector sample(n, -1); - for (int trial = 0; trial < kMaxTrials; ++trial) { - int64_t product = 1; - for (int i = 1; i < n; ++i) { - int value = candidates[dist(rand_)]; - product *= value; - if (product > extent) { - break; - } - sample[i] = value; - } - if (product <= extent) { - sample[0] = (extent + product - 1) / product; - return sample; - } - } - sample[0] = extent; - for (int i = 1; i < n; ++i) { - sample[i] = 1; - } - return sample; - } - /*! - * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent) { - CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; - CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; - // Handle special case that we can potentially accelerate - if (n_splits == 1) { - return {extent}; - } - if (extent == 1) { - return std::vector(n_splits, 1); - } - // Enumerate each pair (i, j), we define - // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) - // (primes[i], j) if i != -1 - // Then the factorization is - // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) - const PrimeTable* prime_tab = PrimeTable::Global(); - std::vector> factorized = prime_tab->Factorize(extent); - if (n_splits == 2) { - // n_splits = 2, this can be taken special care of, - // because general reservoir sampling can be avoided to accelerate the sampling - int result0 = 1; - int result1 = 1; - for (const std::pair& ij : factorized) { - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - (SampleInt(0, 2) ? 
result1 : result0) *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - int x1 = SampleInt(0, p + 1); - int x2 = p - x1; - if (x1 != 0) { - result0 *= pow[x1]; - } - if (x2 != 0) { - result1 *= pow[x2]; - } - } - return {result0, result1}; - } - // Data range: - // 2 <= extent <= 2^31 - 1 - // 3 <= n_splits <= max tiling splits - // 1 <= p <= 31 - std::vector result(n_splits, 1); - for (const std::pair& ij : factorized) { - // Handle special cases to accelerate sampling - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - result[SampleInt(0, n_splits)] *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - if (p == 1) { - result[SampleInt(0, n_splits)] *= prime_tab->primes[ij.first]; - continue; - } - // The general case. We have to sample uniformly from the solution of: - // x_1 + x_2 + ... + x_{n_splits} = p - // where x_i >= 0 - // Data range: - // 2 <= p <= 31 - // 3 <= n_splits <= max tiling splits - std::vector sampled = SampleWithoutReplacement(p + n_splits - 1, n_splits - 1); - std::sort(sampled.begin(), sampled.end()); - sampled.push_back(p + n_splits - 1); - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - for (int i = 0, last = -1; i < n_splits; ++i) { - int x = sampled[i] - last - 1; - last = sampled[i]; - if (x != 0) { - result[i] *= pow[x]; - } - } - } - return result; - } - /*! - * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \param max_innermost_factor A small number indicating the max length of the innermost loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) { - if (max_innermost_factor == -1) { - return this->SamplePerfectTile(n_splits, extent); - } - CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); - } - } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe - // add more heuristics in the future - int innermost = innermost_candidates[SampleInt(0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(n_splits - 1, extent / innermost); - result.push_back(innermost); - return result; - } - /*! - * \brief Sample n floats uniformly in [min, max) - * \param min The left boundary - * \param max The right boundary - * \return The list of floats sampled - */ - std::vector SampleUniform(int n, double min, double max) { - std::uniform_real_distribution dist(min, max); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; - } - /*! 
- * \brief Sample from a Bernoulli distribution - * \param p Parameter in the Bernoulli distribution - * \return return true with probability p, and false with probability (1 - p) - */ - bool SampleBernoulli(double p) { - std::bernoulli_distribution dist(p); - return dist(rand_); - } - /*! - * \brief Create a multinomial sampler based on the specific weights - * \param weights The weights, event probabilities - * \return The multinomial sampler - */ - std::function MakeMultinomial(const std::vector& weights) { - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - for (double w : weights) { - sums.push_back(sum += w); - } - std::uniform_real_distribution dist(0.0, sum); - auto sampler = [this, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { - double p = dist(rand_); - int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? (n - 1) : idx; - }; - return sampler; - } - /*! - * \brief Classic sampling without replacement - * \param n The population size - * \param k The number of samples to be drawn from the population - * \return A list of indices, samples drawn, unsorted and index starting from 0 - */ - std::vector SampleWithoutReplacement(int n, int k) { - if (k == 1) { - return {SampleInt(0, n)}; - } - if (k == 2) { - int result0 = SampleInt(0, n); - int result1 = SampleInt(0, n - 1); - if (result1 >= result0) { - result1 += 1; - } - return {result0, result1}; - } - std::vector order(n); - for (int i = 0; i < n; ++i) { - order[i] = i; - } - for (int i = 0; i < k; ++i) { - int j = SampleInt(i, n); - if (i != j) { - std::swap(order[i], order[j]); - } - } - return {order.begin(), order.begin() + k}; - } - - /*! \brief The default constructor function for Sampler */ - Sampler() = default; - /*! - * \brief Constructor. Construct a sampler with a given random state pointer for its RNG. - * \param random_state The given pointer to random state used to construct the RNG. - * \note The random state is neither initialized not modified by this constructor. - */ - explicit Sampler(support::RandomNumberGenerator::result_type* random_state) - : rand_(random_state) {} - - private: - /*! \brief The random number generator for sampling. 
*/ - support::RandomNumberGenerator rand_; -}; - -} // namespace meta_schedule -} // namespace tvm - -#endif // TVM_META_SCHEDULE_SAMPLER_H_ \ No newline at end of file diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index 75f2b7acdd..83eb73e12f 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,17 +58,18 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } + if (verbose) { LOG(INFO) << "Tuning for task: " << task; } space->Init(task); strategy->Init(task); measurer->Init(task); - return strategy->Search(task, space, measurer, &seeded, verbose); + return strategy->Search(task, space, measurer, &rand_state, verbose); } /********** Printer **********/ @@ -101,17 +102,17 @@ struct Internal { * \brief Apply postprocessors onto the schedule * \param space The search space * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The sampler's random state * \return Whether postprocessing has succeeded * \sa SearchSpaceNode::Postprocess */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return space->Postprocess(task, sch, &seeded); + return space->Postprocess(task, sch, &rand_state); } /*! * \brief Sample a schedule out of the search space, calls SearchSpaceNode::SampleSchedule @@ -122,11 +123,11 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return space->SampleSchedule(task, &seeded); + return space->SampleSchedule(task, &rand_state); } /*! * \brief Get support of the search space, calls SearchSpaceNode::GetSupport @@ -138,11 +139,11 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return space->GetSupport(task, &seeded); + return space->GetSupport(task, &rand_state); } /*! 
* \brief Explore the search space and find the best schedule @@ -156,11 +157,11 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return strategy->Search(task, space, measurer, &seeded, verbose); + return strategy->Search(task, space, measurer, &rand_state, verbose); } }; diff --git a/src/meta_schedule/search.h b/src/meta_schedule/search.h index 7e2edfc06e..9756cde094 100644 --- a/src/meta_schedule/search.h +++ b/src/meta_schedule/search.h @@ -101,22 +101,23 @@ class SearchSpaceNode : public runtime::Object { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The sampler's random state */ - virtual bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) = 0; + virtual bool Postprocess(const SearchTask& task, const Schedule& sch, + Sampler::TRandomState* rand_state) = 0; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - virtual Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) = 0; + virtual Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) = 0; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return The support of the search space. Any point from the search space should along to one of * the traces returned */ - virtual Array GetSupport(const SearchTask& task, Sampler* sampler) = 0; + virtual Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) = 0; static constexpr const char* _type_key = "meta_schedule.SearchSpace"; TVM_DECLARE_BASE_OBJECT_INFO(SearchSpaceNode, Object); @@ -156,8 +157,8 @@ class SearchStrategyNode : public Object { * \return The best schedule found, NullOpt if no valid schedule is found */ virtual Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, - int verbose) = 0; + const ProgramMeasurer& measurer, + Sampler::TRandomState* rand_state, int verbose) = 0; /*! \brief Explore the search space */ virtual void Search() { LOG(FATAL) << "NotImplemented"; } diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 7d32178893..3ad13c02d8 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -49,22 +49,23 @@ class PostOrderApplyNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The sampler's random state */ - bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) override; + bool Postprocess(const SearchTask& task, const Schedule& sch, + Sampler::TRandomState* rand_state) override; /*! 
* \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) override; + Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) override; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa PostOrderApplyNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler* sampler) override; + Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SearchSpaceNode); @@ -97,20 +98,21 @@ PostOrderApply::PostOrderApply(Array stages, Array postpro /********** Sampling **********/ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler* sampler) { - sch->EnterPostproc(); + Sampler::TRandomState* rand_state) { + sch->EnterPostProc(); for (const Postproc& postproc : postprocs) { - if (!postproc->Apply(task, sch, sampler)) { + if (!postproc->Apply(task, sch, rand_state)) { return false; } } return true; } -Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, Sampler* sampler) { - Array support = GetSupport(task, sampler); +Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, + Sampler::TRandomState* rand_state) { + Array support = GetSupport(task, rand_state); ICHECK(!support.empty()) << "ValueError: Found null support"; - int i = sampler->SampleInt(0, support.size()); + int i = Sampler(rand_state).SampleInt(0, support.size()); return support[i]; } @@ -146,12 +148,13 @@ class BlockCollector : public tir::StmtVisitor { const tir::BlockNode* root_block_; }; -Array PostOrderApplyNode::GetSupport(const SearchTask& task, Sampler* sampler) { +Array PostOrderApplyNode::GetSupport(const SearchTask& task, + Sampler::TRandomState* rand_state) { using ScheduleAndUnvisitedBlocks = std::pair>; Array curr{ Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/Sampler(rand_state).ForkSeed(), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail)}; for (const SearchRule& rule : stages) { @@ -201,7 +204,7 @@ Array PostOrderApplyNode::GetSupport(const SearchTask& task, Sampler* Trace trace = sch->trace().value()->Simplified(/*remove_postproc=*/true); Schedule new_sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/Sampler(rand_state).ForkSeed(), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(new_sch, /*remove_postproc=*/true); diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index 4e9c229160..e0774a7ce6 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -38,8 +38,9 @@ Postproc::Postproc(String name, FProc proc) { /********** Postproc **********/ -bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, Sampler* sampler) { - return proc_(task, sch, sampler); +bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, + Sampler::TRandomState* rand_state) { + return proc_(task, sch, rand_state); } /********** RewriteTensorize **********/ @@ -115,7 +116,7 @@ class PostprocRewriteTensorize { 
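Note: every postprocessor in this file follows the same registration pattern, visible in the hunks below: a callable with signature bool(SearchTask, Schedule, void*) is wrapped into a Postproc via Postproc(name, f_proc), and PostprocNode::Apply above simply forwards the random state to it. A hedged sketch of a hypothetical postprocessor written against that interface (the exact FProc typedef lives in postproc.h and is not shown in this excerpt; "my_noop" is a made-up name):

    // Illustrative only.
    Postproc MyNoOpPostproc() {
      auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool {
        // A real postprocessor would inspect or rewrite `sch` here, casting
        // _rand_state back to Sampler::TRandomState* if it needs randomness,
        // and return false when the schedule should be rejected.
        return true;
      };
      return Postproc("my_noop", f_proc);
    }
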
Postproc RewriteTensorize(Array tensor_intrins) { auto f_proc = [tensor_intrins{std::move(tensor_intrins)}](SearchTask task, Schedule self, - void* _sampler) -> bool { + void* _rand_state) -> bool { return PostprocRewriteTensorize(tensor_intrins).Proc(self); }; return Postproc("rewrite_tensorize", f_proc); @@ -181,7 +182,7 @@ class PostprocRewriteCooperativeFetch { }; Postproc RewriteCooperativeFetch() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteCooperativeFetch().Proc(sch); }; return Postproc("rewrite_cooperative_fetch", f_proc); @@ -498,7 +499,7 @@ class PostprocRewriteParallelizeVectorizeUnroll { }; Postproc RewriteParallelizeVectorizeUnroll() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteParallelizeVectorizeUnroll().Proc(sch); }; return Postproc("rewrite_parallelize_vectorize_unroll", f_proc); @@ -632,7 +633,7 @@ class PostprocRewriteUnboundBlocks { }; Postproc RewriteUnboundBlocks() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteUnboundBlocks().Proc(task, sch); }; return Postproc("rewrite_unbound_blocks", f_proc); @@ -764,7 +765,7 @@ class PostprocRewriteReductionBlock { }; Postproc RewriteReductionBlock() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocRewriteReductionBlock().Proc(sch); }; return Postproc("rewrite_reduction_block", f_proc); @@ -794,7 +795,7 @@ class PostprocDisallowDynamicLoops { }; Postproc DisallowDynamicLoops() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocDisallowDynamicLoops().Proc(sch); }; return Postproc("disallow_dynamic_loops", f_proc); @@ -849,7 +850,7 @@ class PostprocVerifyGPUCode { }; Postproc VerifyGPUCode() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostprocVerifyGPUCode().Proc(task, sch); }; return Postproc("verify_gpu_code", f_proc); @@ -1075,8 +1076,8 @@ class PostProcRewriteLayout { } // Step 1: create a new buffer tir::Buffer new_buffer(buffer->data, buffer->dtype, new_shape, Array(), - buffer->elem_offset, buffer->name, - buffer->data_alignment, buffer->offset_factor, buffer->buffer_type); + buffer->elem_offset, buffer->name, buffer->data_alignment, + buffer->offset_factor, buffer->buffer_type); // Step 2: do the rewrite to the buffer access // the rule is as below: // for example, @@ -1104,7 +1105,7 @@ class PostProcRewriteLayout { }; Postproc RewriteLayout() { - auto f_proc = [](SearchTask task, Schedule sch, void* _sampler) -> bool { + auto f_proc = [](SearchTask task, Schedule sch, void* _rand_state) -> bool { return PostProcRewriteLayout().Proc(sch, task); }; return Postproc("rewrite_layout", f_proc); @@ -1118,11 +1119,11 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + 
Sampler(&rand_state).Seed(seed.value()); } - return self->Apply(task, sch, &seeded); + return self->Apply(task, sch, &rand_state); } }; diff --git a/src/meta_schedule/space/postproc.h b/src/meta_schedule/space/postproc.h index 2e9bd604fe..d4786b2c32 100644 --- a/src/meta_schedule/space/postproc.h +++ b/src/meta_schedule/space/postproc.h @@ -44,10 +44,10 @@ class PostprocNode : public Object { /*! * \brief Apply the postprocessor * \param sch The schedule to be processed - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \return If the post-processing succeeds */ - bool Apply(const SearchTask& task, const Schedule& sch, Sampler* sampler); + bool Apply(const SearchTask& task, const Schedule& sch, Sampler::TRandomState* rand_state); static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index 5d056c328f..4097b4b26a 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -47,22 +47,23 @@ class ScheduleFnNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param sampler The random number generator + * \param rand_state The sampler random state */ - bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) override; + bool Postprocess(const SearchTask& task, const Schedule& sch, + Sampler::TRandomState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler* sampler) override; + Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) override; /*! 
* \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa ScheduleFnNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler* sampler) override; + Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SearchSpaceNode); @@ -94,27 +95,29 @@ ScheduleFn::ScheduleFn(PackedFunc sch_fn, Array postprocs) { /********** Sampling **********/ -bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, Sampler* sampler) { +bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, + Sampler::TRandomState* rand_state) { sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { - if (!postproc->Apply(task, sch, sampler)) { + if (!postproc->Apply(task, sch, rand_state)) { return false; } } return true; } -Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler* sampler) { +Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/Sampler(rand_state).ForkSeed(), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); this->sch_fn_(sch); return sch; } -Array ScheduleFnNode::GetSupport(const SearchTask& task, Sampler* sampler) { - return {SampleSchedule(task, sampler)}; +Array ScheduleFnNode::GetSupport(const SearchTask& task, + Sampler::TRandomState* rand_state) { + return {SampleSchedule(task, rand_state)}; } /********** FFI **********/ diff --git a/src/meta_schedule/space/search_rule.cc b/src/meta_schedule/space/search_rule.cc index 406740cbb3..b481d36c46 100644 --- a/src/meta_schedule/space/search_rule.cc +++ b/src/meta_schedule/space/search_rule.cc @@ -27,8 +27,8 @@ namespace tvm { namespace meta_schedule { /**************** TIR Nodes ****************/ -using tir::ForNode; using tir::BlockNode; +using tir::ForNode; /********** Constructors **********/ diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index df1cd19148..bba910d8c4 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -134,12 +134,12 @@ class EvolutionaryNode : public SearchStrategyNode { * \param task The search task * \param space The search space * \param measurer The measurer that builds, runs and profiles sampled programs - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \param verbose Whether or not in verbose mode * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, Sampler::TRandomState* rand_state, int verbose) override; /********** Stages in evolutionary search **********/ @@ -151,22 +151,22 @@ class EvolutionaryNode : public SearchStrategyNode { * \param support The support to be sampled from * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \return The generated samples, all of which are not post-processed */ Array SampleInitPopulation(const Array& support, const 
SearchTask& task, - const SearchSpace& space, Sampler* sampler); + const SearchSpace& space, Sampler::TRandomState* rand_state); /*! * \brief Perform evolutionary search using genetic algorithm with the cost model * \param inits The initial population * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \return An array of schedules, the sampling result */ Array EvolveWithCostModel(const Array& inits, const SearchTask& task, - const SearchSpace& space, Sampler* sampler); + const SearchSpace& space, Sampler::TRandomState* rand_state); /*! * \brief Pick a batch of samples for measurement with epsilon greedy @@ -174,12 +174,12 @@ class EvolutionaryNode : public SearchStrategyNode { * \param bests The best populations according to the cost model when picking top states * \param task The search task * \param space The search space - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \return A list of schedules, result of epsilon-greedy sampling */ Array PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler* sampler); + Sampler::TRandomState* rand_state); /*! * \brief Make measurements and update the cost model @@ -202,14 +202,14 @@ class EvolutionaryNode : public SearchStrategyNode { /*! * \brief Fork a sampler into `n` samplers * \param n The number of samplers to be forked - * \param sampler The sampler to be forked - * \return A list of samplers, the result of forking + * \param rand_state The sampler's random state + * \return A list of random states, the result of forking */ - static std::vector ForkSamplers(int n, Sampler* sampler) { - std::vector result; + static std::vector ForkSamplers(int n, Sampler::TRandomState* rand_state) { + std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) { - result.emplace_back(sampler->ForkSeed()); + result.emplace_back(Sampler(rand_state).ForkSeed()); } return result; } @@ -226,19 +226,16 @@ class EvolutionaryNode : public SearchStrategyNode { /*! * \brief Replay the trace and do postprocessing - * \param n The number of samplers to be forked - * \param sampler The sampler to be forked - * \return A list of samplers, the result of forking */ static Optional ReplayTrace(const Trace& trace, const SearchTask& task, - const SearchSpace& space, Sampler* sampler, + const SearchSpace& space, Sampler::TRandomState* rand_state, const tir::PrimFunc& workload) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), workload}}), - /*seed=*/sampler->ForkSeed(), + /*seed=*/Sampler(rand_state).ForkSeed(), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - if (!space->Postprocess(task, sch, sampler)) { + if (!space->Postprocess(task, sch, rand_state)) { return NullOpt; } return sch; @@ -246,11 +243,12 @@ class EvolutionaryNode : public SearchStrategyNode { /*! 
* \brief Create a sampler function that picks mutators according to the mass function - * \param sampler The source of randomness + * \param rand_state The sampler's random state * \return The sampler created */ static std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, Sampler* sampler) { + double p_mutate, const Map& mutator_probs, + Sampler::TRandomState* rand_state) { CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // << "ValueError: Probability should be within [0, 1], " << "but get `p_mutate = " << p_mutate << '\''; @@ -279,7 +277,7 @@ class EvolutionaryNode : public SearchStrategyNode { masses[i] /= total_mass_mutator; } } - auto idx_sampler = sampler->MakeMultinomial(masses); + auto idx_sampler = Sampler(rand_state).MakeMultinomial(masses); return [idx_sampler = std::move(idx_sampler), mutators = std::move(mutators)]() -> Optional { int i = idx_sampler(); @@ -312,7 +310,6 @@ class EvolutionaryNode : public SearchStrategyNode { * \param candidates The candidates for prediction * \param task The search task * \param space The search space - * \param sampler Source of randomness * \return The normalized throughput in the prediction */ std::vector PredictNormalizedThroughput(const std::vector& candidates, @@ -427,8 +424,8 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i CHECK_LE(num_measures_per_iteration, population) << "ValueError: requires `num_measures_per_iteration <= population`"; { - Sampler sampler(42); - EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &sampler); + Sampler::TRandomState rand_state = 42; + EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); } ObjectPtr n = make_object(); n->total_measures = total_measures; @@ -447,23 +444,23 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i /********** Search **********/ Optional EvolutionaryNode::Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, - int verbose) { - Array support = space->GetSupport(task, sampler); + const ProgramMeasurer& measurer, + Sampler::TRandomState* rand_state, int verbose) { + Array support = space->GetSupport(task, rand_state); int iter = 1; for (int num_measured = 0; num_measured < this->total_measures; ++iter) { LOG(INFO) << "Evolutionary search: Iteration #" << iter << " | Measured: " << num_measured << "/" << this->total_measures; // `inits`: Sampled initial population, whose size is at most `this->population` LOG(INFO) << "Sampling initial population..."; - Array inits = SampleInitPopulation(support, task, space, sampler); + Array inits = SampleInitPopulation(support, task, space, rand_state); LOG(INFO) << "Initial population size: " << inits.size(); // `bests`: The best schedules according to the cost mode when explore the space using mutators LOG(INFO) << "Evolving..."; - Array bests = EvolveWithCostModel(inits, task, space, sampler); + Array bests = EvolveWithCostModel(inits, task, space, rand_state); // Pick candidates with eps greedy LOG(INFO) << "Picking with epsilon greedy where epsilon = " << eps_greedy; - Array picks = PickWithEpsGreedy(inits, bests, task, space, sampler); + Array picks = PickWithEpsGreedy(inits, bests, task, space, rand_state); // Run measurement, update cost model LOG(INFO) << "Sending " << picks.size() << " samples for measurement"; Array results = MeasureAndUpdateCostModel(task, picks, measurer, verbose); @@ -475,25 +472,26 @@ Optional EvolutionaryNode::Search(const SearchTask& 
task, const Search Array EvolutionaryNode::SampleInitPopulation(const Array& support, const SearchTask& task, const SearchSpace& space, - Sampler* global_sampler) { + Sampler::TRandomState* global_rand_state) { trace_cache_.clear(); std::vector results; results.reserve(this->population); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_samplers = ForkSamplers(num_threads, global_sampler); + std::vector thread_rand_states = + ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); // Pick measured states int num_measured = this->population * this->init_measured_ratio; for (const Database::Entry& entry : database->GetTopK(num_measured, task)) { results.push_back(entry.trace.value()); } - auto f_proc_measured = [this, &results, &thread_samplers, &task, &space, thread_workloads]( + auto f_proc_measured = [this, &results, &thread_rand_states, &task, &space, thread_workloads]( int thread_id, int i) -> void { - Sampler* sampler = &thread_samplers[thread_id]; + Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; const Trace& trace = results[i]; if (Optional opt_sch = - ReplayTrace(trace, task, space, sampler, thread_workloads[thread_id])) { + ReplayTrace(trace, task, space, rand_state, thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); this->AddCachedTrace(CachedTrace{trace.get(), sch, Repr(sch), -1.0}); } else { @@ -505,15 +503,17 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo // Pick unmeasured states std::atomic tot_fail_ct(0); std::atomic success_ct(0); - auto f_proc_unmeasured = [this, &results, &thread_samplers, &tot_fail_ct, &task, &space, &support, - &success_ct, thread_workloads](int thread_id, int i) -> void { - Sampler* sampler = &thread_samplers[thread_id]; + auto f_proc_unmeasured = [this, &results, &thread_rand_states, &tot_fail_ct, &task, &space, + &support, &success_ct, thread_workloads](int thread_id, int i) -> void { + Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; for (;;) { - Trace support_trace = support[sampler->SampleInt(0, support.size())]->trace().value(); + Trace support_trace = + support[Sampler(rand_state).SampleInt(0, support.size())]->trace().value(); Map decisions; try { - if (Optional opt_sch = ReplayTrace(Trace(support_trace->insts, decisions), task, - space, sampler, thread_workloads[thread_id])) { + if (Optional opt_sch = + ReplayTrace(Trace(support_trace->insts, decisions), task, space, rand_state, + thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); Trace old_trace = sch->trace().value(); Trace trace(old_trace->insts, old_trace->decisions); @@ -547,24 +547,26 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, const SearchTask& task, const SearchSpace& space, - Sampler* global_sampler) { + Sampler::TRandomState* global_rand_state) { // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(this->num_measures_per_iteration); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_samplers = ForkSamplers(num_threads, global_sampler); + std::vector thread_rand_states = + ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); std::vector> thread_trace_samplers(num_threads); 
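// ---------------------------------------------------------------------------
// Illustrative sketch, not the patch's MakeMultinomial: the per-thread trace
// and mutator samplers set up in this hunk boil down to drawing an index i
// with probability masses[i] / sum(masses). Standard C++'s
// std::discrete_distribution does the normalization; std::minstd_rand stands
// in for the engine used by this series, and the helper name is hypothetical.
#include <functional>
#include <random>
#include <vector>

inline std::function<int()> MakeMultinomialSketch(const std::vector<double>& masses,
                                                  std::minstd_rand* rng) {
  std::discrete_distribution<int> dist(masses.begin(), masses.end());
  return [dist, rng]() mutable -> int { return dist(*rng); };
}
// Usage: std::minstd_rand rng(42);
//        auto pick = MakeMultinomialSketch({0.2, 0.5, 0.3}, &rng);
//        int chosen = pick();  // 0, 1, or 2, proportional to the masses
// ---------------------------------------------------------------------------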
std::vector()>> thread_mutator_samplers(num_threads); std::vector trace_used; std::mutex trace_used_mutex; - auto f_set_sampler = [this, num_threads, &thread_samplers, &thread_trace_samplers, + auto f_set_sampler = [this, num_threads, &thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used](const std::vector& scores) { for (int i = 0; i < num_threads; ++i) { - Sampler* sampler = &thread_samplers[i]; - thread_trace_samplers[i] = sampler->MakeMultinomial(scores); - thread_mutator_samplers[i] = MakeMutatorSampler(this->p_mutate, this->mutator_probs, sampler); + Sampler::TRandomState* rand_state = &thread_rand_states[i]; + thread_trace_samplers[i] = Sampler(rand_state).MakeMultinomial(scores); + thread_mutator_samplers[i] = + MakeMutatorSampler(this->p_mutate, this->mutator_probs, rand_state); } trace_used = std::vector(scores.size(), 0); }; @@ -595,11 +597,11 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, // Set threaded samplers, with probability from predicated normalized throughputs f_set_sampler(scores); // The worker function - auto f_find_candidate = [&thread_samplers, &thread_trace_samplers, &thread_mutator_samplers, + auto f_find_candidate = [&thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used, &trace_used_mutex, &sch_curr, &sch_next, &task, &space, thread_workloads, this](int thread_id, int i) { // Prepare samplers - Sampler* sampler = &thread_samplers[thread_id]; + Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; const std::function& trace_sampler = thread_trace_samplers[thread_id]; const std::function()>& mutator_sampler = thread_mutator_samplers[thread_id]; @@ -613,10 +615,10 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, // Decision: mutate Mutator mutator = opt_mutator.value(); if (Optional opt_new_trace = - mutator->Apply(task, GetRef(cached_trace.trace), sampler)) { + mutator->Apply(task, GetRef(cached_trace.trace), rand_state)) { Trace new_trace = opt_new_trace.value(); if (Optional opt_sch = - ReplayTrace(new_trace, task, space, sampler, thread_workloads[thread_id])) { + ReplayTrace(new_trace, task, space, rand_state, thread_workloads[thread_id])) { Schedule sch = opt_sch.value(); CachedTrace new_cached_trace{new_trace.get(), sch, Repr(sch), -1.0}; this->AddCachedTrace(new_cached_trace); @@ -675,10 +677,11 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, Array EvolutionaryNode::PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, - const SearchSpace& space, Sampler* sampler) { + const SearchSpace& space, + Sampler::TRandomState* rand_state) { int num_rands = this->num_measures_per_iteration * this->eps_greedy; int num_bests = this->num_measures_per_iteration - num_rands; - std::vector rands = sampler->SampleWithoutReplacement(inits.size(), inits.size()); + std::vector rands = Sampler(rand_state).SampleWithoutReplacement(inits.size(), inits.size()); Array results; results.reserve(this->num_measures_per_iteration); for (int i = 0, i_bests = 0, i_rands = 0; i < this->num_measures_per_iteration; ++i) { @@ -780,11 +783,11 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return self->SampleInitPopulation(support, task, space, &seeded); + return self->SampleInitPopulation(support, 
task, space, &rand_state); } /*! * \brief Perform evolutionary search using genetic algorithm with the cost model @@ -798,11 +801,11 @@ struct Internal { */ static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return self->EvolveWithCostModel(inits, task, space, &seeded); + return self->EvolveWithCostModel(inits, task, space, &rand_state); } /*! @@ -816,11 +819,11 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return self->PickWithEpsGreedy(inits, bests, task, space, &seeded); + return self->PickWithEpsGreedy(inits, bests, task, space, &rand_state); } /*! diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index c425c8e5c8..a62931ed1e 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -35,8 +35,9 @@ Mutator::Mutator(String name, FApply apply) { /********** Mutator **********/ -Optional MutatorNode::Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { - return apply_(task, trace, sampler); +Optional MutatorNode::Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state) { + return apply_(task, trace, rand_state); } /********** MutateTileSize **********/ @@ -77,17 +78,18 @@ class MutatorTileSize { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state) { // Find instruction `SamplePerfectTile` whose extent > 1 and n_splits > 1 std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Instruction& inst = candidates[sampler->SampleInt(0, candidates.size())]; + const Instruction& inst = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; std::vector tiles = CastDecision(trace->decisions.at(inst)); int n_splits = tiles.size(); // Choose two loops - int x = sampler->SampleInt(0, n_splits); + int x = Sampler(rand_state).SampleInt(0, n_splits); int y; if (tiles[x] == 1) { // need to guarantee that tiles[x] * tiles[y] > 1 @@ -98,10 +100,10 @@ class MutatorTileSize { idx.push_back(i); } } - y = idx[sampler->SampleInt(0, idx.size())]; + y = idx[Sampler(rand_state).SampleInt(0, idx.size())]; } else { // sample without replacement - y = sampler->SampleInt(0, n_splits - 1); + y = Sampler(rand_state).SampleInt(0, n_splits - 1); if (y >= x) { ++y; } @@ -115,7 +117,7 @@ class MutatorTileSize { int len_x, len_y; if (y != n_splits - 1) { do { - std::vector result = sampler->SamplePerfectTile(2, tiles[x] * tiles[y]); + std::vector result = Sampler(rand_state).SamplePerfectTile(2, tiles[x] * tiles[y]); len_x = result[0]; len_y = result[1]; } while (len_y == tiles[y]); @@ -132,7 +134,7 @@ class MutatorTileSize { if (len_y_space.empty()) { return NullOpt; } - len_y = len_y_space[sampler->SampleInt(0, len_y_space.size())]; + len_y = len_y_space[Sampler(rand_state).SampleInt(0, len_y_space.size())]; len_x = prod / len_y; } tiles[x] = len_x; @@ -142,9 +144,9 @@ class MutatorTileSize { }; Mutator MutateTileSize() { - auto f_apply 
= [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorTileSize mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_tile_size", f_apply); } @@ -216,21 +218,22 @@ class MutatorComputeLocation { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state) { std::vector candidates = FindCandidates(trace, task->workload); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[sampler->SampleInt(0, candidates.size())]; - int loc = candidate.locs[sampler->SampleInt(0, candidate.locs.size())]; + const Candidate& candidate = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; + int loc = candidate.locs[Sampler(rand_state).SampleInt(0, candidate.locs.size())]; return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); } }; Mutator MutateComputeLocation() { - auto f_apply = [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorComputeLocation mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_compute_location", f_apply); } @@ -308,13 +311,14 @@ class MutatorAutoUnroll { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) { + Optional Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state) { std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[sampler->SampleInt(0, candidates.size())]; - int result = sampler->MakeMultinomial(candidate.weights)(); + const Candidate& candidate = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; + int result = Sampler(rand_state).MakeMultinomial(candidate.weights)(); if (result >= candidate.ori_decision) { result++; } @@ -323,9 +327,9 @@ class MutatorAutoUnroll { }; Mutator MutateAutoUnroll() { - auto f_apply = [](SearchTask task, Trace trace, void* sampler) -> Optional { + auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorAutoUnroll mutator; - return mutator.Apply(task, trace, static_cast(sampler)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_unroll_depth", f_apply); } @@ -429,7 +433,8 @@ class MutatorParallel { return Candidate(Instruction{nullptr}, {}); } - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler) const { + Optional Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state) const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); int max_extent = GetTargetNumCores(task->target, &warned_num_cores_missing) * max_jobs_per_core - 1; @@ -439,7 +444,8 @@ class MutatorParallel { } const BlockRV& block = Downcast(candidate.inst->inputs[0]); const std::vector& extent_candidates = candidate.extent_candidates; - int parallel_size = extent_candidates[sampler->SampleInt(0, extent_candidates.size())]; + int parallel_size = + extent_candidates[Sampler(rand_state).SampleInt(0, extent_candidates.size())]; std::vector new_insts; for 
(const Instruction& inst : trace->insts) { @@ -464,8 +470,8 @@ class MutatorParallel { Mutator MutateParallel(const int& max_jobs_per_core) { MutatorParallel mutator(max_jobs_per_core); - auto f_apply = [mutator](SearchTask task, Trace trace, void* sampler) -> Optional { - return mutator.Apply(task, trace, static_cast(sampler)); + auto f_apply = [mutator](SearchTask task, Trace trace, void* rand_state) -> Optional { + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_parallel", f_apply); } @@ -479,11 +485,11 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - Sampler seeded; + Sampler::TRandomState rand_state; if (seed.defined()) { - seeded.Seed(seed.value()); + Sampler(&rand_state).Seed(seed.value()); } - return mutator->Apply(task, trace, &seeded); + return mutator->Apply(task, trace, &rand_state); } }; diff --git a/src/meta_schedule/strategy/mutator.h b/src/meta_schedule/strategy/mutator.h index 79d1e8ccf8..acc11cbbee 100644 --- a/src/meta_schedule/strategy/mutator.h +++ b/src/meta_schedule/strategy/mutator.h @@ -44,10 +44,11 @@ class MutatorNode : public Object { * \brief Mutate the schedule by applying the mutation * \param task The search task * \param trace The trace to be mutated - * \param sampler The random number sampler + * \param rand_state The sampler's random state * \return The new schedule after mutation, NullOpt if mutation fails */ - Optional Apply(const SearchTask& task, const Trace& trace, Sampler* sampler); + Optional Apply(const SearchTask& task, const Trace& trace, + Sampler::TRandomState* rand_state); static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); diff --git a/src/meta_schedule/strategy/replay.cc b/src/meta_schedule/strategy/replay.cc index b6f24cf70d..c3966438f3 100644 --- a/src/meta_schedule/strategy/replay.cc +++ b/src/meta_schedule/strategy/replay.cc @@ -51,7 +51,7 @@ class ReplayNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, + const ProgramMeasurer& measurer, Sampler::TRandomState* rand_state, int verbose) override; static constexpr const char* _type_key = "meta_schedule.Replay"; @@ -86,21 +86,21 @@ Replay::Replay(int batch_size, int num_trials) { /********** Search **********/ Optional ReplayNode::Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler* sampler, - int verbose) { - std::vector thread_samplers; + const ProgramMeasurer& measurer, + Sampler::TRandomState* rand_state, int verbose) { + std::vector thread_rand_states; std::vector thread_measure_inputs; - thread_samplers.reserve(this->batch_size); + thread_rand_states.reserve(this->batch_size); thread_measure_inputs.reserve(this->batch_size); for (int i = 0; i < batch_size; ++i) { - thread_samplers.emplace_back(sampler->ForkSeed()); + thread_rand_states.emplace_back(Sampler(rand_state).ForkSeed()); thread_measure_inputs.emplace_back(nullptr); } - auto worker = [&task, &space, &thread_samplers, &thread_measure_inputs](int thread_id, int i) { - Sampler* sampler = &thread_samplers[i]; + auto worker = [&task, &space, &thread_rand_states, &thread_measure_inputs](int thread_id, int i) { + Sampler::TRandomState* rand_state = &thread_rand_states[i]; for (;;) { - Schedule sch = space->SampleSchedule(task, sampler); - if 
(space->Postprocess(task, sch, sampler)) { + Schedule sch = space->SampleSchedule(task, rand_state); + if (space->Postprocess(task, sch, rand_state)) { thread_measure_inputs[i] = MeasureInput(task, sch); break; } diff --git a/src/support/rng.h b/src/support/rng.h index 7b4123d9d0..d6eda38638 100644 --- a/src/support/rng.h +++ b/src/support/rng.h @@ -99,6 +99,7 @@ class RandomNumberGenerator { * \return The next current random state value in the type of result_type. */ result_type next_state() { + if (increment == 0 && *rand_state_ptr == 0) *rand_state_ptr = 1; (*rand_state_ptr) = ((*rand_state_ptr) * multiplier + increment) % modulus; return *rand_state_ptr; } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 447f40cdc0..8415a94776 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -28,7 +28,7 @@ Schedule Schedule::Concrete(IRModule mod, int64_t seed, int debug_mode, ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; - n->sampler_.Seed(seed); + Sampler(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -185,7 +185,7 @@ Schedule ConcreteScheduleNode::Copy(int64_t new_seed) const { Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); - n->sampler_.Seed(new_seed); + Sampler(&n->rand_state_).Seed(new_seed); return Schedule(std::move(n)); } @@ -218,7 +218,7 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int int max_innermost_factor, Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SamplePerfectTile(state_, &this->sampler_, this->GetSRef(loop_rv), n, + return CreateRV(tir::SamplePerfectTile(state_, &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); } @@ -227,7 +227,7 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SampleCategorical(state_, &this->sampler_, candidates, probs, &decision)); + return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); } @@ -235,7 +235,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV( - tir::SampleComputeLocation(state_, &this->sampler_, this->GetSRef(block_rv), &decision)); + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 429b372004..8cf054c35e 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -42,7 +42,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief Source of randomness */ - Sampler sampler_; + Sampler::TRandomState rand_state_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. 
*/ @@ -53,7 +53,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `error_render_level_` is not visited // `state_` is not visited // `error_render_level_` is not visited - // `sampler_` is not visited + // `rand_state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied } @@ -66,9 +66,11 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } - Schedule Copy(int64_t new_seed = -1) const override; - void Seed(int64_t new_seed = -1) final { this->sampler_.Seed(new_seed); } - int64_t ForkSeed() final { return this->sampler_.ForkSeed(); } + Schedule Copy(Sampler::TRandomState new_seed = -1) const override; + void Seed(Sampler::TRandomState new_seed = -1) final { + Sampler(&this->rand_state_).Seed(new_seed); + } + Sampler::TRandomState ForkSeed() final { return Sampler(&this->rand_state_).ForkSeed(); } public: /******** Lookup random variables ********/ @@ -83,7 +85,7 @@ class ConcreteScheduleNode : public ScheduleNode { void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } using ScheduleNode::GetSRef; - + public: /******** Schedule: Sampling ********/ Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index c30e53f9ff..22b28201c9 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,19 +26,20 @@ namespace tvm { namespace tir { - class Sampler; /******** Schedule: Sampling ********/ -TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler, +TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, + Sampler::TRandomState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision); -TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, +TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandomState* rand_state, const Array& candidates, const Array& probs, Optional* decision); -TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, +TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, + Sampler::TRandomState* rand_state, const tir::StmtSRef& block_sref, Optional* decision); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 66f3f8179d..6d403b1d30 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -22,7 +22,7 @@ namespace tvm { namespace tir { -std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler, +std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandomState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision) { @@ -52,7 +52,8 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler result[0] = len; } else { // Case 3. 
Use fresh new sampling result - std::vector sampled = sampler->SamplePerfectTile(n, extent, max_innermost_factor); + std::vector sampled = + Sampler(rand_state).SamplePerfectTile(n, extent, max_innermost_factor); result = std::vector(sampled.begin(), sampled.end()); ICHECK_LE(sampled.back(), max_innermost_factor); } @@ -60,7 +61,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler* sampler return result; } -int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, +int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandomState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { int i = -1; @@ -71,14 +72,14 @@ int64_t SampleCategorical(tir::ScheduleState self, Sampler* sampler, CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - i = sampler->MakeMultinomial(AsVector(probs))(); + i = Sampler(rand_state).MakeMultinomial(AsVector(probs))(); ICHECK(0 <= i && i < n); } *decision = Integer(i); return candidates[i]; } -tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler::TRandomState* rand_state, const tir::StmtSRef& block_sref, Optional* decision) { // Find all possible compute-at locations Array loop_srefs = tir::CollectComputeLocation(self, block_sref); @@ -111,7 +112,7 @@ tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler* sampler, } } else { // Sample possible combinations - i = sampler->SampleInt(-2, choices.size()); + i = Sampler(rand_state).SampleInt(-2, choices.size()); if (i >= 0) { i = choices[i]; } diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc index 52d62b88f5..e50e69b43e 100644 --- a/src/tir/schedule/sampler.cc +++ b/src/tir/schedule/sampler.cc @@ -126,15 +126,14 @@ struct PrimeTable { } }; -int Sampler::ForkSeed() { - uint32_t a = this->rand_(); - uint32_t b = this->rand_(); - uint32_t c = this->rand_(); - uint32_t d = this->rand_(); - return (a ^ b) * (c ^ d) % 1145141; +Sampler::TRandomState Sampler::ForkSeed() { + // In order for reproducibility, we computer the new seed using sampler's RNG's current random + // state and a different set of parameters. Note that 32767 & 1999999973 are prime numbers. 
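// ---------------------------------------------------------------------------
// Illustrative sketch, not the patch's code: the comment above derives the
// child seed from the parent's current random state with a different
// multiplier/modulus pair (32767, 1999999973) and then advances the parent
// state, so repeated forks give distinct but reproducible seeds. The helper
// names are hypothetical; std::minstd_rand's parameters (48271, 2^31 - 1) are
// assumed for the state advance.
#include <cstdint>

inline int64_t AdvanceMinstdState(int64_t state) {
  return (state * 48271) % 2147483647;  // multiplier and modulus of std::minstd_rand
}

inline int64_t ForkSeedFrom(int64_t* parent_state) {
  int64_t child_seed = (*parent_state * 32767) % 1999999973;  // constants used by ForkSeed here
  *parent_state = AdvanceMinstdState(*parent_state);          // the parent state moves on
  return child_seed;
}
// Forking n worker states: call ForkSeedFrom(&global_state) once per thread and
// seed each thread-local engine with the returned value.
// ---------------------------------------------------------------------------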
+ Sampler::TRandomState ret = (this->rand_.random_state() * 32767) % 1999999973; + this->rand_.next_state(); + return ret; } - -void Sampler::Seed(int seed) { this->rand_.seed(seed); } +void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.seed(seed); } int Sampler::SampleInt(int min_inclusive, int max_exclusive) { if (min_inclusive + 1 == max_exclusive) { @@ -358,9 +357,10 @@ static inline bool IsCudaTarget(const Target& target) { return false; } -std::vector> Sampler::SampleShapeGenericTiles( - const std::vector& n_splits, const std::vector& max_extents, - const Target& target, int max_innermost_factor) { +std::vector> Sampler::SampleShapeGenericTiles(const std::vector& n_splits, + const std::vector& max_extents, + const Target& target, + int max_innermost_factor) { std::vector> ret_split_factors; if (IsCudaTarget(target)) { @@ -374,9 +374,9 @@ std::vector> Sampler::SampleShapeGenericTiles( int max_threads_per_block; int max_innermost_factor; int max_vthread; - } constraints = { - ExtractInt(target, "shared_memory_per_block"), ExtractInt(target, "registers_per_block"), - ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; + } constraints = {ExtractInt(target, "shared_memory_per_block"), + ExtractInt(target, "registers_per_block"), + ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; for (const int n_split : n_splits) { ret_split_factors.push_back(std::vector(n_split, 1)); @@ -403,8 +403,7 @@ std::vector> Sampler::SampleShapeGenericTiles( do { all_below_max_extents = true; - num_threads_factor_scheme = - SamplePerfectTile(num_spatial_axes, num_threads_per_block); + num_threads_factor_scheme = SamplePerfectTile(num_spatial_axes, num_threads_per_block); for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { if (n_splits[iter_id] == 4) { if (num_threads_factor_scheme[spatial_iter_id] > max_extents[iter_id]) { @@ -429,119 +428,100 @@ std::vector> Sampler::SampleShapeGenericTiles( auto sample_factors = [&](std::function continue_predicate, std::function max_extent, std::function factor_to_assign) { - std::vector iter_max_extents; - std::vector factors_to_assign; - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - size_t iter_max_extent = max_extent(iter_id), factor_to_assign; - - std::uniform_int_distribution<> dist(1, iter_max_extent); - factor_to_assign = SampleInt(1, iter_max_extent); - - if (n_splits[iter_id] == 4) { - reg_usage *= factor_to_assign; - } else { - shmem_usage *= factor_to_assign; - } - iter_max_extents.push_back(iter_max_extent); - factors_to_assign.push_back(factor_to_assign); - } - // shuffle the factors - std::vector factors_to_assign_bak = factors_to_assign; - Shuffle(factors_to_assign.begin(), factors_to_assign.end()); - // make sure that the shuffle is valid - bool valid_shuffle = true; - std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), - factors_to_assign_it = factors_to_assign.begin(); - - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - int iter_max_extent = *iter_max_extents_it; - if (*factors_to_assign_it > iter_max_extent) { - valid_shuffle = false; - } - ++iter_max_extents_it; - ++factors_to_assign_it; - } - if (!valid_shuffle) { - factors_to_assign = std::move(factors_to_assign_bak); - } - // do the actual assignment - factors_to_assign_it = factors_to_assign.begin(); - for (size_t iter_id = 0; iter_id < n_splits.size(); 
++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - factor_to_assign(iter_id) = *factors_to_assign_it; - ++factors_to_assign_it; - } - }; + std::vector iter_max_extents; + std::vector factors_to_assign; + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + size_t iter_max_extent = max_extent(iter_id), factor_to_assign; + + std::uniform_int_distribution<> dist(1, iter_max_extent); + factor_to_assign = SampleInt(1, iter_max_extent); + + if (n_splits[iter_id] == 4) { + reg_usage *= factor_to_assign; + } else { + shmem_usage *= factor_to_assign; + } + iter_max_extents.push_back(iter_max_extent); + factors_to_assign.push_back(factor_to_assign); + } + // shuffle the factors + std::vector factors_to_assign_bak = factors_to_assign; + Shuffle(factors_to_assign.begin(), factors_to_assign.end()); + // make sure that the shuffle is valid + bool valid_shuffle = true; + std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), + factors_to_assign_it = factors_to_assign.begin(); + + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + int iter_max_extent = *iter_max_extents_it; + if (*factors_to_assign_it > iter_max_extent) { + valid_shuffle = false; + } + ++iter_max_extents_it; + ++factors_to_assign_it; + } + if (!valid_shuffle) { + factors_to_assign = std::move(factors_to_assign_bak); + } + // do the actual assignment + factors_to_assign_it = factors_to_assign.begin(); + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + factor_to_assign(iter_id) = *factors_to_assign_it; + ++factors_to_assign_it; + } + }; sample_factors( [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || - (iter_id != last_spatial_iter_id); + return (n_splits[iter_id] != 4) || (iter_id != last_spatial_iter_id); }, [&](const size_t iter_id) -> int { - size_t max_vthread_extent = - std::min(constraints.max_vthread, - max_extents[iter_id] / ret_split_factors[iter_id][1]); + size_t max_vthread_extent = std::min( + constraints.max_vthread, max_extents[iter_id] / ret_split_factors[iter_id][1]); max_vthread_extent = - std::min(constraints.max_vthread, - constraints.max_local_memory_per_block / reg_usage); + std::min(constraints.max_vthread, constraints.max_local_memory_per_block / reg_usage); return max_vthread_extent; }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][0]; - } - ); + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); // factor[3] (innermost) sample_factors( [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || - (iter_id == last_spatial_iter_id); + return (n_splits[iter_id] != 4) || (iter_id == last_spatial_iter_id); }, [&](const size_t iter_id) -> int { int max_innermost_extent = - std::min(max_innermost_factor, - max_extents[iter_id] / ret_split_factors[iter_id][0] - / ret_split_factors[iter_id][1]); + std::min(max_innermost_factor, max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1]); max_innermost_extent = - std::min(max_innermost_extent, - constraints.max_local_memory_per_block / reg_usage); + std::min(max_innermost_extent, constraints.max_local_memory_per_block / reg_usage); return max_innermost_extent; }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][3]; - } - ); + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][3]; 
}); // factor[2] - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4); - }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] - / ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], - constraints.max_local_memory_per_block / reg_usage - ); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][2]; - } - ); + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 4); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], + constraints.max_local_memory_per_block / reg_usage); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][2]; }); for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { if (n_splits[iter_id] == 4) { - shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] - * ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; + shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] * + ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; } } if (shmem_usage > static_cast(constraints.max_shared_memory_per_block / sizeof(float))) { @@ -550,40 +530,25 @@ std::vector> Sampler::SampleShapeGenericTiles( // repeat similar procedure for reduction axes // rfactor[1] (innermost) sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 2); - }, + [&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, [&](const size_t iter_id) -> int { - int max_innermost_extent = - std::min(max_innermost_factor, max_extents[iter_id]); - max_innermost_extent = - std::min(max_innermost_extent, - static_cast( - constraints.max_shared_memory_per_block / sizeof(float) / shmem_usage - )); + int max_innermost_extent = std::min(max_innermost_factor, max_extents[iter_id]); + max_innermost_extent = std::min(max_innermost_extent, + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); return max_innermost_extent; }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][1]; - } - ); + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][1]; }); // rfactor[0] - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 2); - }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], - static_cast( - constraints.max_shared_memory_per_block / sizeof(float) / shmem_usage - )); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { - return ret_split_factors[iter_id][0]; - } - ); + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); } // if (IsCudaTarget(target)) return ret_split_factors; } diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h index 5aa87f984b..da1ff4b238 100644 --- 
a/src/tir/schedule/sampler.h +++ b/src/tir/schedule/sampler.h @@ -24,19 +24,31 @@ #include #include +#include "../support/rng.h" namespace tvm { class Target; namespace tir { -/*! \brief Random number sampler used for sampling in meta schedule */ +/*! + * \brief Sampler based on random number generator for sampling in meta schedule. + * \note Typical usage is like Sampler(&random_state).SamplingFunc(...). + */ class Sampler { public: - /*! \brief Return a seed that can be used to create a new sampler */ - int ForkSeed(); - /*! \brief Re-seed the random number generator */ - void Seed(int seed); + /*! Random state type for random number generator. */ + using TRandomState = support::RandomNumberGenerator::result_type; + /*! + * \brief Return a random state value that can be used as seed for new samplers. + * \return The random state value to be used as seed for new samplers. + */ + TRandomState ForkSeed(); + /*! + * \brief Re-seed the random number generator + * \param seed The random state value given used to re-seed the RNG. + */ + void Seed(TRandomState seed); /*! * \brief Sample an integer in [min_inclusive, max_exclusive) * \param min_inclusive The left boundary, inclusive @@ -56,7 +68,7 @@ class Sampler { * \param begin_it The begin iterator * \param end_it The end iterator */ - template + template void Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it); /*! * \brief Sample n tiling factors of the specific extent @@ -123,28 +135,25 @@ class Sampler { * \return A list of indices, samples drawn, unsorted and index starting from 0 */ std::vector SampleWithoutReplacement(int n, int k); + /*! \brief The default constructor function for Sampler */ + Sampler() = default; /*! - * \brief Constructor. Construct a sampler seeded with std::random_device + * \brief Constructor. Construct a sampler with a given random state pointer for its RNG. + * \param random_state The given pointer to random state used to construct the RNG. + * \note The random state is neither initialized not modified by this constructor. */ - Sampler() : Sampler(std::random_device /**/ {}()) {} - /*! - * \brief Constructor. Construct a sampler seeded with the specific integer - * \param seed The random seed - */ - explicit Sampler(int seed) : rand_(seed) {} + explicit Sampler(TRandomState* random_state) : rand_(random_state) {} private: - /*! \brief The random number generator */ - std::minstd_rand rand_; + /*! \brief The random number generator for sampling. 
*/ + support::RandomNumberGenerator rand_; }; - -template +template void Sampler::Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it) { std::shuffle(begin_it, end_it, rand_); } - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e2883ddb8f..f741dd2e45 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,24 +21,24 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, int64_t seed, int debug_mode, +Schedule Schedule::Traced(IRModule mod, Sampler::TRandomState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; - n->sampler_.Seed(seed); + Sampler(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); return Schedule(std::move(n)); } -Schedule TracedScheduleNode::Copy(int64_t new_seed) const { +Schedule TracedScheduleNode::Copy(Sampler::TRandomState new_seed) const { ObjectPtr n = make_object(); ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); - n->sampler_.Seed(new_seed); + Sampler(&n->rand_state_).Seed(new_seed); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); return Schedule(std::move(n)); } @@ -48,8 +48,9 @@ Schedule TracedScheduleNode::Copy(int64_t new_seed) const { Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision) { - Array results = CreateRV(tir::SamplePerfectTile( - this->state_, &this->sampler_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); + Array results = + CreateRV(tir::SamplePerfectTile(this->state_, &this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision)); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -63,8 +64,8 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - ExprRV result = - CreateRV(tir::SampleCategorical(this->state_, &this->sampler_, candidates, probs, &decision)); + ExprRV result = CreateRV( + tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -77,7 +78,7 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, Optional decision) { - LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->sampler_, + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ff362efa99..ed6582060e 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -34,7 +34,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void VisitAttrs(tvm::AttrVisitor* v) { // `state_` is not visited // `error_render_level_` is not visited - // `sampler_` is 
not visited + // `rand_state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied // `trace_` is not visited diff --git a/tests/cpp/meta_schedule_test.cc b/tests/cpp/meta_schedule_test.cc index e65621a563..2d71598c15 100644 --- a/tests/cpp/meta_schedule_test.cc +++ b/tests/cpp/meta_schedule_test.cc @@ -20,13 +20,13 @@ #include #include -#include "../../../src/meta_schedule/sampler.h" +#include "../../../src/tir/schedule/sampler.h" TEST(Simplify, Sampler) { int64_t current = 100; for (int i = 0; i < 10; i++) { - tvm::meta_schedule::Sampler(¤t).SampleInt(0, 100); - tvm::meta_schedule::Sampler(¤t).SampleUniform(3, -1, 0); + tvm::tir::Sampler(¤t).SampleInt(0, 100); + tvm::tir::Sampler(¤t).SampleUniform(3, -1, 0); } } From d69e6783dd142d92e75f01eaf4d466ad9f0ac346 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Jul 2021 17:23:51 -0700 Subject: [PATCH 03/23] Update include headers. --- src/meta_schedule/space/schedule_fn.cc | 2 +- src/tir/schedule/primitive.h | 2 ++ src/tir/schedule/sampler.h | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index 4097b4b26a..c479fefb41 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -47,7 +47,7 @@ class ScheduleFnNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param rand_state The sampler random state + * \param rand_state The sampler's random state */ bool Postprocess(const SearchTask& task, const Schedule& sch, Sampler::TRandomState* rand_state) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 22b28201c9..0b6cf1f909 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -23,6 +23,8 @@ #include +#include "sampler.h" + namespace tvm { namespace tir { diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h index da1ff4b238..c77129357f 100644 --- a/src/tir/schedule/sampler.h +++ b/src/tir/schedule/sampler.h @@ -24,7 +24,7 @@ #include #include -#include "../support/rng.h" +#include "../../support/rng.h" namespace tvm { class Target; From 8f3dce16618a49d48a2bde2336ff50e62302156f Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Jul 2021 17:26:35 -0700 Subject: [PATCH 04/23] Add new line --- tests/cpp/meta_schedule_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/meta_schedule_test.cc b/tests/cpp/meta_schedule_test.cc index 2d71598c15..4b68da7003 100644 --- a/tests/cpp/meta_schedule_test.cc +++ b/tests/cpp/meta_schedule_test.cc @@ -34,4 +34,4 @@ int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); -} \ No newline at end of file +} From aeb20bff429635342977aaba7ffa01f3f0dbab20 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 1 Aug 2021 20:53:53 -0700 Subject: [PATCH 05/23] Fix function name capital char & remove unused constructor function. --- src/support/rng.h | 5 +---- src/tir/schedule/sampler.cc | 2 +- tests/python/meta_schedule/test_meta_schedule_feature.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/support/rng.h b/src/support/rng.h index d6eda38638..6506fe45dd 100644 --- a/src/support/rng.h +++ b/src/support/rng.h @@ -50,9 +50,6 @@ class RandomNumberGenerator { /*! 
\brief The modulus */ static constexpr result_type modulus = 2147483647; - /*! \brief Construct a null random number generator. */ - RandomNumberGenerator() { rand_state_ptr = nullptr; } - /*! * \brief Construct a random number generator with a random state pointer. * \param random_state The random state pointer given in result_type*. @@ -66,7 +63,7 @@ class RandomNumberGenerator { * \note The seed is used to initialize the random number generator and the random state would be * changed to next random state by calling the next_state() function. */ - void seed(result_type state = 1) { + void Seed(result_type state = 1) { state %= modulus; // Make sure the seed is within the range of the modulus. if (state < 0) state += modulus; // The congruential engine is always non-negative. ICHECK(rand_state_ptr != nullptr); // Make sure the pointer is not null. diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc index e50e69b43e..5ef3436c02 100644 --- a/src/tir/schedule/sampler.cc +++ b/src/tir/schedule/sampler.cc @@ -133,7 +133,7 @@ Sampler::TRandomState Sampler::ForkSeed() { this->rand_.next_state(); return ret; } -void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.seed(seed); } +void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.Seed(seed); } int Sampler::SampleInt(int min_inclusive, int max_exclusive) { if (min_inclusive + 1 == max_exclusive) { diff --git a/tests/python/meta_schedule/test_meta_schedule_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature.py index 57b0edefe8..49ead3062c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature.py @@ -675,7 +675,7 @@ def _check_compute(feature): 1, 0, 0, 29, 20, 23, 14, 1, 0, 0, 18, 20.005626, 4.0874629, 25, 16, 19, 10.0014086, 1, ], write_feature=[ - 0, 1, 0, 29, 12.000352, 23, 9.002815, 1, 0, 0, 10.001408, 13.000176, 8.005625, 21, 4.087463, 15, + 0, 1, 0, 29, 12.000352, 23, 9.002815, 1, 0, 0, 10.001408, 13.000176, 8.005625, 21, 4.087463, 15, #pylint: disable=line-too-long 1.584963, 1, ], # fmt: on From ec17e82b3b284dc8fb4ac94b06c8e9d58c6e16a4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 2 Aug 2021 16:06:23 -0700 Subject: [PATCH 06/23] Modify according to reviews. --- {src => include/tvm}/support/rng.h | 26 ++++++++++++++------------ src/tir/schedule/sampler.cc | 1 + src/tir/schedule/sampler.h | 8 ++++---- 3 files changed, 19 insertions(+), 16 deletions(-) rename {src => include/tvm}/support/rng.h (74%) diff --git a/src/support/rng.h b/include/tvm/support/rng.h similarity index 74% rename from src/support/rng.h rename to include/tvm/support/rng.h index 6506fe45dd..92e0bc61e8 100644 --- a/src/support/rng.h +++ b/include/tvm/support/rng.h @@ -34,9 +34,11 @@ namespace tvm { namespace support { /*! - * \brief The random number generator is implemented as a linear congruential engine. + * \brief This linear congruential engine is a drop-in replacement for and stricly corresponds to + * std::minstd_rand but designed to be serializable and strictly reproducible. Specifically + * implemented for meta schedule but also reusable for other purposes. */ -class RandomNumberGenerator { +class LinearCongruentialEngine { public: /*! \brief The result type is defined as int64_t here for sampler usage. */ using result_type = int64_t; @@ -55,7 +57,7 @@ class RandomNumberGenerator { * \param random_state The random state pointer given in result_type*. * \note The random state is not initialized here. You may need to call seed function. 
*/ - explicit RandomNumberGenerator(result_type* random_state) { rand_state_ptr = random_state; } + explicit LinearCongruentialEngine(result_type* random_state) { rand_state_ptr_ = random_state; } /*! * \brief Change the start random state of RNG with the seed of a new random state value. @@ -64,10 +66,10 @@ class RandomNumberGenerator { * changed to next random state by calling the next_state() function. */ void Seed(result_type state = 1) { - state %= modulus; // Make sure the seed is within the range of the modulus. - if (state < 0) state += modulus; // The congruential engine is always non-negative. - ICHECK(rand_state_ptr != nullptr); // Make sure the pointer is not null. - *rand_state_ptr = state; // Change pointed random state to given random state value. + state %= modulus; // Make sure the seed is within the range of the modulus. + if (state < 0) state += modulus; // The congruential engine is always non-negative. + ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. + *rand_state_ptr_ = state; // Change pointed random state to given random state value. next_state(); }; @@ -81,7 +83,7 @@ class RandomNumberGenerator { * \brief Fetch the current random state. * \return The current random state value in the type of result_type. */ - result_type random_state() { return *rand_state_ptr; } + result_type random_state() { return *rand_state_ptr_; } /*! * \brief Operator to fetch the current random state. @@ -96,13 +98,13 @@ class RandomNumberGenerator { * \return The next current random state value in the type of result_type. */ result_type next_state() { - if (increment == 0 && *rand_state_ptr == 0) *rand_state_ptr = 1; - (*rand_state_ptr) = ((*rand_state_ptr) * multiplier + increment) % modulus; - return *rand_state_ptr; + if (increment == 0 && *rand_state_ptr_ == 0) *rand_state_ptr_ = 1; + (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; + return *rand_state_ptr_; } private: - result_type* rand_state_ptr; + result_type* rand_state_ptr_; }; } // namespace support diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc index 5ef3436c02..f2c7534511 100644 --- a/src/tir/schedule/sampler.cc +++ b/src/tir/schedule/sampler.cc @@ -133,6 +133,7 @@ Sampler::TRandomState Sampler::ForkSeed() { this->rand_.next_state(); return ret; } + void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.Seed(seed); } int Sampler::SampleInt(int min_inclusive, int max_exclusive) { diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h index c77129357f..6bb3312fb0 100644 --- a/src/tir/schedule/sampler.h +++ b/src/tir/schedule/sampler.h @@ -19,12 +19,12 @@ #ifndef TVM_TIR_SCHEDULE_SAMPLER_H_ #define TVM_TIR_SCHEDULE_SAMPLER_H_ +#include + #include #include #include #include - -#include "../../support/rng.h" namespace tvm { class Target; @@ -38,7 +38,7 @@ namespace tir { class Sampler { public: /*! Random state type for random number generator. */ - using TRandomState = support::RandomNumberGenerator::result_type; + using TRandomState = support::LinearCongruentialEngine::result_type; /*! * \brief Return a random state value that can be used as seed for new samplers. * \return The random state value to be used as seed for new samplers. @@ -146,7 +146,7 @@ class Sampler { private: /*! \brief The random number generator for sampling. 
*/ - support::RandomNumberGenerator rand_; + support::LinearCongruentialEngine rand_; }; template From cb2158fd315fdc68ac8fcdde8b2839905ee0d1c9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 2 Aug 2021 22:18:04 -0700 Subject: [PATCH 07/23] Change file name and add checks. --- include/tvm/support/{rng.h => random_engine.h} | 9 ++++++--- src/tir/schedule/sampler.cc | 2 +- src/tir/schedule/sampler.h | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) rename include/tvm/support/{rng.h => random_engine.h} (95%) diff --git a/include/tvm/support/rng.h b/include/tvm/support/random_engine.h similarity index 95% rename from include/tvm/support/rng.h rename to include/tvm/support/random_engine.h index 92e0bc61e8..7756b08bc0 100644 --- a/include/tvm/support/rng.h +++ b/include/tvm/support/random_engine.h @@ -57,7 +57,11 @@ class LinearCongruentialEngine { * \param random_state The random state pointer given in result_type*. * \note The random state is not initialized here. You may need to call seed function. */ - explicit LinearCongruentialEngine(result_type* random_state) { rand_state_ptr_ = random_state; } + explicit LinearCongruentialEngine(result_type* random_state) { + ICHECK(random_state != nullptr); // Make sure the pointer is not null. + rand_state_ptr_ = random_state; + seed(*random_state); + } /*! * \brief Change the start random state of RNG with the seed of a new random state value. @@ -65,12 +69,11 @@ class LinearCongruentialEngine { * \note The seed is used to initialize the random number generator and the random state would be * changed to next random state by calling the next_state() function. */ - void Seed(result_type state = 1) { + void seed(result_type state = 1) { state %= modulus; // Make sure the seed is within the range of the modulus. if (state < 0) state += modulus; // The congruential engine is always non-negative. ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. *rand_state_ptr_ = state; // Change pointed random state to given random state value. - next_state(); }; /*! \brief The minimum possible value of random state here. */ diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc index f2c7534511..cd2f2e1133 100644 --- a/src/tir/schedule/sampler.cc +++ b/src/tir/schedule/sampler.cc @@ -134,7 +134,7 @@ Sampler::TRandomState Sampler::ForkSeed() { return ret; } -void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.Seed(seed); } +void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.seed(seed); } int Sampler::SampleInt(int min_inclusive, int max_exclusive) { if (min_inclusive + 1 == max_exclusive) { diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h index 6bb3312fb0..60828c3237 100644 --- a/src/tir/schedule/sampler.h +++ b/src/tir/schedule/sampler.h @@ -19,7 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_SAMPLER_H_ #define TVM_TIR_SCHEDULE_SAMPLER_H_ -#include +#include #include #include From 0ad9e260808c72926070b85a935b467636db294d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 3 Aug 2021 18:04:03 -0700 Subject: [PATCH 08/23] Fix seed = -1 situation. 
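The convention adopted by the call sites touched below can be summarized by the following sketch. It is illustrative only and not code added by this patch; `ResolveSeed` and its signature are hypothetical, used here just to show the intended behavior of the special value -1.

    #include <cstdint>
    #include <random>

    // Hypothetical helper: a seed of -1 means "non-deterministic", i.e. draw
    // fresh entropy from the platform via std::random_device; any other value
    // is used as-is so runs can be reproduced.
    inline int64_t ResolveSeed(int64_t seed = -1) {
      if (seed == -1) {
        seed = static_cast<int64_t>(std::random_device()());
      }
      return seed;
    }

    // The call sites below follow this pattern inline, e.g. (sketch):
    //   if (seed == -1) seed = std::random_device()();
    //   Sampler(&rand_state).Seed(seed);
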
--- include/tvm/support/random_engine.h | 87 +++++++++---------- src/meta_schedule/autotune.cc | 4 +- .../cost_model/rand_cost_model.cc | 3 +- src/meta_schedule/search.cc | 10 +-- src/meta_schedule/space/postproc.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 6 +- src/meta_schedule/strategy/mutator.cc | 2 +- src/tir/schedule/concrete_schedule.cc | 2 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/sampler.cc | 10 +-- src/tir/schedule/traced_schedule.cc | 2 + 11 files changed, 68 insertions(+), 61 deletions(-) diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 7756b08bc0..dfbeab406f 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -18,13 +18,12 @@ */ /*! - * \file rng.h - * \brief Random number generator, for Sampler and Sampling - * functions. + * \file random_engine.h + * \brief Random number generator, for Sampler and Sampling functions. */ -#ifndef TVM_SUPPORT_RNG_H_ -#define TVM_SUPPORT_RNG_H_ +#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ +#define TVM_SUPPORT_RANDOM_ENGINE_H_ #include @@ -37,10 +36,16 @@ namespace support { * \brief This linear congruential engine is a drop-in replacement for and stricly corresponds to * std::minstd_rand but designed to be serializable and strictly reproducible. Specifically * implemented for meta schedule but also reusable for other purposes. + * \note Part of std::linear_congruential_engine's member functions are not included, for full + * member functions of std::minstd_rand, please check out the following link: + * https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine */ class LinearCongruentialEngine { public: - /*! \brief The result type is defined as int64_t here for sampler usage. */ + /*! + * \brief The result type is defined as int64_t here for meta_schedule sampler usage. + * \note The type name is not in Google style because it is used in STL's distribution inferface. + */ using result_type = int64_t; /*! \brief The multiplier */ @@ -53,57 +58,51 @@ class LinearCongruentialEngine { static constexpr result_type modulus = 2147483647; /*! - * \brief Construct a random number generator with a random state pointer. - * \param random_state The random state pointer given in result_type*. - * \note The random state is not initialized here. You may need to call seed function. + * \brief The minimum possible value of random state here. + * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - explicit LinearCongruentialEngine(result_type* random_state) { - ICHECK(random_state != nullptr); // Make sure the pointer is not null. - rand_state_ptr_ = random_state; - seed(*random_state); - } + result_type min() { return 0; } /*! - * \brief Change the start random state of RNG with the seed of a new random state value. - * \param random_state The random state given in result_type. - * \note The seed is used to initialize the random number generator and the random state would be - * changed to next random state by calling the next_state() function. + * \brief The maximum possible value of random state here. + * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - void seed(result_type state = 1) { - state %= modulus; // Make sure the seed is within the range of the modulus. - if (state < 0) state += modulus; // The congruential engine is always non-negative. - ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. 
- *rand_state_ptr_ = state; // Change pointed random state to given random state value. - }; - - /*! \brief The minimum possible value of random state here. */ - result_type min() { return 0; } - - /*! \brief The maximum possible value of random state here. */ result_type max() { return modulus - 1; } /*! - * \brief Fetch the current random state. - * \return The current random state value in the type of result_type. + * \brief Operator to move the random state to the next and return the new random state. According + * to definition of linear congruential engine, the new random state value is computed as + * new_random_state = (current_random_state * multiplier + increment) % modulus. + * \return The next current random state value in the type of result_type. + * \note In case of potential overflow, please use Schrage multiplication algorithm to implement. + * We also assume the given rand state is not nullptr here. */ - result_type random_state() { return *rand_state_ptr_; } + result_type operator()() { + // Avoid getting all 0 given the current parameter set. + if (increment == 0 && *rand_state_ptr_ == 0) *rand_state_ptr_ = 1; + (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; + return *rand_state_ptr_; + } /*! - * \brief Operator to fetch the current random state. - * \return The current random state value in the type of result_type. + * \brief Change the start random state of RNG with the seed of a new random state value. + * \param rand_state The random state given in result_type. */ - result_type operator()() { return next_state(); } + void Seed(result_type rand_state = 1) { + rand_state %= modulus; // Make sure the seed is within the range of modulus. + if (rand_state < 0) rand_state += modulus; // The congruential engine is always non-negative. + ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. + *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. + }; /*! - * \brief Move the random state to the next and return the new random state. According to - * definition of linear congruential engine, the new random state value is computed as - * new_random_state = (current_random_state * multiplier + increment) % modulus. - * \return The next current random state value in the type of result_type. + * \brief Construct a random number generator with a random state pointer. + * \param rand_state_ptr The random state pointer given in result_type*. + * \note The random state is not checked for whether it's nullptr and whether it's in the range of + * [0, modulus-1]. We assume the given random state is valid or the Seed function would be called. 
*/ - result_type next_state() { - if (increment == 0 && *rand_state_ptr_ == 0) *rand_state_ptr_ = 1; - (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; - return *rand_state_ptr_; + explicit LinearCongruentialEngine(result_type* rand_state_ptr) { + rand_state_ptr_ = rand_state_ptr; } private: @@ -113,4 +112,4 @@ class LinearCongruentialEngine { } // namespace support } // namespace tvm -#endif // TVM_SUPPORT_RNG_H_ +#endif // TVM_SUPPORT_RANDOM_ENGINE_H_ diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index 24d2a1bf39..7d6aec3e8a 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -24,8 +24,10 @@ namespace tvm { namespace meta_schedule { void TuneContextNode::Init(Optional seed) { - if (seed.defined()) { + if (seed.defined() && seed.value() != -1) { Sampler(&this->rand_state).Seed(seed.value()->value); + } else { + Sampler(&this->rand_state).Seed(std::random_device()()); } if (task.defined()) { task.value()->Init(this); diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 92b78fa078..0fc8a49fa1 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -76,7 +76,8 @@ class RandCostModel : public CostModel { struct Internal { static RandCostModel New(Optional seed) { - return seed.defined() ? RandCostModel(seed.value()->value) : RandCostModel(); + return seed.defined() ? RandCostModel(seed.value()->value) + : RandCostModel(std::random_device()()); } }; diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index 83eb73e12f..0b0ca2c896 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,7 +58,7 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -108,7 +108,7 @@ struct Internal { */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -123,7 +123,7 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -139,7 +139,7 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -157,7 +157,7 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index e0774a7ce6..bfac4a87cb 100644 --- 
a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -1119,7 +1119,7 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index bba910d8c4..725d6912ce 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -783,7 +783,7 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -801,7 +801,7 @@ struct Internal { */ static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -819,7 +819,7 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index a62931ed1e..42f34b9f05 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -485,7 +485,7 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - Sampler::TRandomState rand_state; + Sampler::TRandomState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8415a94776..be66eaeb52 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -28,6 +28,7 @@ Schedule Schedule::Concrete(IRModule mod, int64_t seed, int debug_mode, ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; + if (seed == -1) seed = std::random_device()(); Sampler(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); @@ -185,6 +186,7 @@ Schedule ConcreteScheduleNode::Copy(int64_t new_seed) const { Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); + if (new_seed == -1) new_seed = std::random_device()(); Sampler(&n->rand_state_).Seed(new_seed); return Schedule(std::move(n)); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8cf054c35e..a4dbf61bd0 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,6 +68,7 @@ class ConcreteScheduleNode : public ScheduleNode { Optional trace() const override { return NullOpt; } Schedule Copy(Sampler::TRandomState new_seed = -1) const override; void Seed(Sampler::TRandomState new_seed = -1) final { + if (new_seed == -1) new_seed = std::random_device()(); 
     Sampler(&this->rand_state_).Seed(new_seed);
   }
   Sampler::TRandomState ForkSeed() final { return Sampler(&this->rand_state_).ForkSeed(); }
diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc
index cd2f2e1133..ac76a13786 100644
--- a/src/tir/schedule/sampler.cc
+++ b/src/tir/schedule/sampler.cc
@@ -127,14 +127,14 @@ struct PrimeTable {
 };

 Sampler::TRandomState Sampler::ForkSeed() {
-  // In order for reproducibility, we computer the new seed using sampler's RNG's current random
-  // state and a different set of parameters. Note that 32767 & 1999999973 are prime numbers.
-  Sampler::TRandomState ret = (this->rand_.random_state() * 32767) % 1999999973;
-  this->rand_.next_state();
+  // In order for reproducibility, we compute the new seed using sampler's RNG's random state and a
+  // different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
+  Sampler::TRandomState ret = (this->rand_() * 32767) % 1999999973;
   return ret;
 }

-void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.seed(seed); }
+// We don't need to check the seed here because it's checked in LCE's seed function.
+void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.Seed(seed); }

 int Sampler::SampleInt(int min_inclusive, int max_exclusive) {
   if (min_inclusive + 1 == max_exclusive) {
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index f741dd2e45..29e4ac56ad 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -26,6 +26,7 @@ Schedule Schedule::Traced(IRModule mod, Sampler::TRandomState seed, int debug_mo
   ObjectPtr n = make_object();
   n->state_ = ScheduleState(mod, debug_mode);
   n->error_render_level_ = error_render_level;
+  if (seed == -1) seed = std::random_device()();
   Sampler(&n->rand_state_).Seed(seed);
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique();
@@ -38,6 +39,7 @@ Schedule TracedScheduleNode::Copy(Sampler::TRandomState new_seed) const {
   ObjectPtr n = make_object();
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
   n->error_render_level_ = this->error_render_level_;
   n->analyzer_ = std::make_unique();
+  if (new_seed == -1) new_seed = std::random_device()();
   Sampler(&n->rand_state_).Seed(new_seed);
   n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
   return Schedule(std::move(n));

From b8e0d63869fbe7abf357fe9eda975f23523e3bf8 Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Thu, 5 Aug 2021 16:13:33 -0700
Subject: [PATCH 09/23] Update random engine usage.
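As the header comment below notes, the reworked engine is meant to satisfy the standard UniformRandomBitGenerator requirements, so STL distributions can draw from it directly while the random state itself stays in a plain, serializable integer owned by the caller. A minimal usage sketch, assuming the interface introduced in this patch (the `Example` function is a placeholder, not code added by the patch):

    #include <tvm/support/random_engine.h>
    #include <random>

    void Example() {
      using tvm::support::LinearCongruentialEngine;
      // The engine does not own its state; it only wraps a pointer to it, so
      // the state can be stored, serialized, and restored independently.
      LinearCongruentialEngine::TRandState state = 0;
      LinearCongruentialEngine rng(&state);
      rng.Seed(42);  // Seed() normalizes 0 and negative values into [1, modulus - 1].
      // Any STL distribution can consume the engine as its bit generator.
      std::uniform_int_distribution<int> dist(0, 99);
      int x = dist(rng);
      (void)x;
      // Saving `state` and restoring it later reproduces the same sequence.
    }
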
--- include/tvm/support/random_engine.h | 54 +++++++------- src/meta_schedule/autotune.h | 2 +- .../cost_model/rand_cost_model.cc | 2 +- src/meta_schedule/search.cc | 10 +-- src/meta_schedule/search.h | 8 +-- src/meta_schedule/space/post_order_apply.cc | 12 ++-- src/meta_schedule/space/postproc.cc | 4 +- src/meta_schedule/space/postproc.h | 2 +- src/meta_schedule/space/schedule_fn.cc | 14 ++-- src/meta_schedule/strategy/evolutionary.cc | 44 ++++++------ src/meta_schedule/strategy/mutator.cc | 20 +++--- src/meta_schedule/strategy/mutator.h | 2 +- src/meta_schedule/strategy/replay.cc | 8 +-- src/tir/schedule/concrete_schedule.h | 8 +-- src/tir/schedule/primitive.h | 6 +- src/tir/schedule/primitive/sampling.cc | 6 +- src/tir/schedule/sampler.cc | 6 +- src/tir/schedule/sampler.h | 8 +-- src/tir/schedule/traced_schedule.cc | 4 +- tests/cpp/random_engine_test.cc | 71 +++++++++++++++++++ 20 files changed, 184 insertions(+), 107 deletions(-) create mode 100644 tests/cpp/random_engine_test.cc diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index dfbeab406f..e73c1193f4 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -27,18 +27,19 @@ #include -#include // for int64_t +#include // for uint64_t namespace tvm { namespace support { /*! - * \brief This linear congruential engine is a drop-in replacement for and stricly corresponds to - * std::minstd_rand but designed to be serializable and strictly reproducible. Specifically - * implemented for meta schedule but also reusable for other purposes. - * \note Part of std::linear_congruential_engine's member functions are not included, for full - * member functions of std::minstd_rand, please check out the following link: - * https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine + * \brief This linear congruential engine is a drop-in replacement for std::minstd_rand. It strictly + * corresponds to std::minstd_rand and is designed to be platform-independent. + * \note Our linear congruential engine is a complete implementation of + * std::uniform_random_bit_generator so it can be used as generator for any STL random number + * distribution. However, parts of std::linear_congruential_engine's member functions are not + * included for simplification. For full member functions of std::minstd_rand, please check out the + * following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine */ class LinearCongruentialEngine { public: @@ -46,16 +47,17 @@ class LinearCongruentialEngine { * \brief The result type is defined as int64_t here for meta_schedule sampler usage. * \note The type name is not in Google style because it is used in STL's distribution inferface. */ - using result_type = int64_t; + using result_type = uint64_t; + using TRandState = int64_t; /*! \brief The multiplier */ - static constexpr result_type multiplier = 48271; + static constexpr TRandState multiplier = 48271; /*! \brief The increment */ - static constexpr result_type increment = 0; + static constexpr TRandState increment = 0; /*! \brief The modulus */ - static constexpr result_type modulus = 2147483647; + static constexpr TRandState modulus = 2147483647; /*! * \brief The minimum possible value of random state here. @@ -71,15 +73,15 @@ class LinearCongruentialEngine { /*! * \brief Operator to move the random state to the next and return the new random state. 
According - * to definition of linear congruential engine, the new random state value is computed as + * to definition of linear congruential engine, the new random state value is computed as * new_random_state = (current_random_state * multiplier + increment) % modulus. * \return The next current random state value in the type of result_type. - * \note In case of potential overflow, please use Schrage multiplication algorithm to implement. - * We also assume the given rand state is not nullptr here. + * \note In order for better efficiency, the implementation here has a few assumptions: + * 1. The multiplication and addition won't overflow. + * 2. The given random state pointer `rand_state_ptr` is not nullptr. + * 3. The given random state `*(rand_state_ptr)` is in the range of [0, modulus - 1]. */ result_type operator()() { - // Avoid getting all 0 given the current parameter set. - if (increment == 0 && *rand_state_ptr_ == 0) *rand_state_ptr_ = 1; (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; return *rand_state_ptr_; } @@ -88,25 +90,29 @@ class LinearCongruentialEngine { * \brief Change the start random state of RNG with the seed of a new random state value. * \param rand_state The random state given in result_type. */ - void Seed(result_type rand_state = 1) { + void Seed(TRandState rand_state = 1) { rand_state %= modulus; // Make sure the seed is within the range of modulus. - if (rand_state < 0) rand_state += modulus; // The congruential engine is always non-negative. - ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. - *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. - }; + if (rand_state == 0) + rand_state = 1; // Avoid getting all 0 given the current parameter set. + else if (rand_state < 0) + rand_state += modulus; // Make sure the rand state is non-negative. + ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. + *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. + } /*! * \brief Construct a random number generator with a random state pointer. * \param rand_state_ptr The random state pointer given in result_type*. * \note The random state is not checked for whether it's nullptr and whether it's in the range of - * [0, modulus-1]. We assume the given random state is valid or the Seed function would be called. + * [0, modulus-1]. We assume the given random state is valid or the Seed function would be + * called right after the constructor before any usage. 
*/ - explicit LinearCongruentialEngine(result_type* rand_state_ptr) { + explicit LinearCongruentialEngine(TRandState* rand_state_ptr) { rand_state_ptr_ = rand_state_ptr; } private: - result_type* rand_state_ptr_; + TRandState* rand_state_ptr_; }; } // namespace support diff --git a/src/meta_schedule/autotune.h b/src/meta_schedule/autotune.h index 3196bb8f72..f56973fa42 100644 --- a/src/meta_schedule/autotune.h +++ b/src/meta_schedule/autotune.h @@ -44,7 +44,7 @@ class TuneContextNode : public runtime::Object { Array measure_callbacks; int num_threads; - Sampler::TRandomState rand_state; + Sampler::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 0fc8a49fa1..75304c3888 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -28,7 +28,7 @@ namespace meta_schedule { class RandCostModelNode : public CostModelNode { public: /*! \brief A random state for sampler to generate random numbers */ - Sampler::TRandomState rand_state; + Sampler::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { // sampler is not visited diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index 0b0ca2c896..02fe13d2f0 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,7 +58,7 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -108,7 +108,7 @@ struct Internal { */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -123,7 +123,7 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -139,7 +139,7 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -157,7 +157,7 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/search.h b/src/meta_schedule/search.h index 9756cde094..7b4cfa6859 100644 --- a/src/meta_schedule/search.h +++ b/src/meta_schedule/search.h @@ -104,20 +104,20 @@ class SearchSpaceNode : public runtime::Object { * \param rand_state The sampler's random state */ virtual bool Postprocess(const SearchTask& task, const Schedule& 
sch, - Sampler::TRandomState* rand_state) = 0; + Sampler::TRandState* rand_state) = 0; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - virtual Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) = 0; + virtual Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) = 0; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return The support of the search space. Any point from the search space should along to one of * the traces returned */ - virtual Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) = 0; + virtual Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) = 0; static constexpr const char* _type_key = "meta_schedule.SearchSpace"; TVM_DECLARE_BASE_OBJECT_INFO(SearchSpaceNode, Object); @@ -158,7 +158,7 @@ class SearchStrategyNode : public Object { */ virtual Optional Search(const SearchTask& task, const SearchSpace& space, const ProgramMeasurer& measurer, - Sampler::TRandomState* rand_state, int verbose) = 0; + Sampler::TRandState* rand_state, int verbose) = 0; /*! \brief Explore the search space */ virtual void Search() { LOG(FATAL) << "NotImplemented"; } diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 3ad13c02d8..9011b012ca 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -52,20 +52,20 @@ class PostOrderApplyNode : public SearchSpaceNode { * \param rand_state The sampler's random state */ bool Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandomState* rand_state) override; + Sampler::TRandState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) override; + Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) override; /*! 
* \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa PostOrderApplyNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) override; + Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SearchSpaceNode); @@ -98,7 +98,7 @@ PostOrderApply::PostOrderApply(Array stages, Array postpro /********** Sampling **********/ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { sch->EnterPostProc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { @@ -109,7 +109,7 @@ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch } Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { Array support = GetSupport(task, rand_state); ICHECK(!support.empty()) << "ValueError: Found null support"; int i = Sampler(rand_state).SampleInt(0, support.size()); @@ -149,7 +149,7 @@ class BlockCollector : public tir::StmtVisitor { }; Array PostOrderApplyNode::GetSupport(const SearchTask& task, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { using ScheduleAndUnvisitedBlocks = std::pair>; Array curr{ diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index bfac4a87cb..61ebb75ec0 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -39,7 +39,7 @@ Postproc::Postproc(String name, FProc proc) { /********** Postproc **********/ bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { return proc_(task, sch, rand_state); } @@ -1119,7 +1119,7 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/space/postproc.h b/src/meta_schedule/space/postproc.h index d4786b2c32..3b388d4e7b 100644 --- a/src/meta_schedule/space/postproc.h +++ b/src/meta_schedule/space/postproc.h @@ -47,7 +47,7 @@ class PostprocNode : public Object { * \param rand_state The sampler's random state * \return If the post-processing succeeds */ - bool Apply(const SearchTask& task, const Schedule& sch, Sampler::TRandomState* rand_state); + bool Apply(const SearchTask& task, const Schedule& sch, Sampler::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index c479fefb41..902529e200 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -50,20 +50,20 @@ class ScheduleFnNode : public SearchSpaceNode { * \param rand_state The sampler's random state */ bool Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandomState* rand_state) override; + Sampler::TRandState* rand_state) override; 
/*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) override; + Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) override; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa ScheduleFnNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler::TRandomState* rand_state) override; + Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SearchSpaceNode); @@ -96,8 +96,8 @@ ScheduleFn::ScheduleFn(PackedFunc sch_fn, Array postprocs) { /********** Sampling **********/ bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandomState* rand_state) { - sch->EnterPostproc(); + Sampler::TRandState* rand_state) { + sch->EnterPostProc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { return false; @@ -106,7 +106,7 @@ bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, return true; } -Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler::TRandomState* rand_state) { +Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), /*seed=*/Sampler(rand_state).ForkSeed(), /*debug_mode=*/false, @@ -116,7 +116,7 @@ Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler::TRandom } Array ScheduleFnNode::GetSupport(const SearchTask& task, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { return {SampleSchedule(task, rand_state)}; } diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index 725d6912ce..5bfa415d23 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -139,7 +139,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler::TRandomState* rand_state, + const ProgramMeasurer& measurer, Sampler::TRandState* rand_state, int verbose) override; /********** Stages in evolutionary search **********/ @@ -155,7 +155,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return The generated samples, all of which are not post-processed */ Array SampleInitPopulation(const Array& support, const SearchTask& task, - const SearchSpace& space, Sampler::TRandomState* rand_state); + const SearchSpace& space, Sampler::TRandState* rand_state); /*! * \brief Perform evolutionary search using genetic algorithm with the cost model @@ -166,7 +166,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return An array of schedules, the sampling result */ Array EvolveWithCostModel(const Array& inits, const SearchTask& task, - const SearchSpace& space, Sampler::TRandomState* rand_state); + const SearchSpace& space, Sampler::TRandState* rand_state); /*! 
* \brief Pick a batch of samples for measurement with epsilon greedy @@ -179,7 +179,7 @@ class EvolutionaryNode : public SearchStrategyNode { */ Array PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler::TRandomState* rand_state); + Sampler::TRandState* rand_state); /*! * \brief Make measurements and update the cost model @@ -205,8 +205,8 @@ class EvolutionaryNode : public SearchStrategyNode { * \param rand_state The sampler's random state * \return A list of random states, the result of forking */ - static std::vector ForkSamplers(int n, Sampler::TRandomState* rand_state) { - std::vector result; + static std::vector ForkSamplers(int n, Sampler::TRandState* rand_state) { + std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) { result.emplace_back(Sampler(rand_state).ForkSeed()); @@ -228,7 +228,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \brief Replay the trace and do postprocessing */ static Optional ReplayTrace(const Trace& trace, const SearchTask& task, - const SearchSpace& space, Sampler::TRandomState* rand_state, + const SearchSpace& space, Sampler::TRandState* rand_state, const tir::PrimFunc& workload) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), workload}}), /*seed=*/Sampler(rand_state).ForkSeed(), @@ -248,7 +248,7 @@ class EvolutionaryNode : public SearchStrategyNode { */ static std::function()> MakeMutatorSampler( double p_mutate, const Map& mutator_probs, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // << "ValueError: Probability should be within [0, 1], " << "but get `p_mutate = " << p_mutate << '\''; @@ -424,7 +424,7 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i CHECK_LE(num_measures_per_iteration, population) << "ValueError: requires `num_measures_per_iteration <= population`"; { - Sampler::TRandomState rand_state = 42; + Sampler::TRandState rand_state = 42; EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); } ObjectPtr n = make_object(); @@ -445,7 +445,7 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i Optional EvolutionaryNode::Search(const SearchTask& task, const SearchSpace& space, const ProgramMeasurer& measurer, - Sampler::TRandomState* rand_state, int verbose) { + Sampler::TRandState* rand_state, int verbose) { Array support = space->GetSupport(task, rand_state); int iter = 1; for (int num_measured = 0; num_measured < this->total_measures; ++iter) { @@ -472,13 +472,13 @@ Optional EvolutionaryNode::Search(const SearchTask& task, const Search Array EvolutionaryNode::SampleInitPopulation(const Array& support, const SearchTask& task, const SearchSpace& space, - Sampler::TRandomState* global_rand_state) { + Sampler::TRandState* global_rand_state) { trace_cache_.clear(); std::vector results; results.reserve(this->population); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = + std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); // Pick measured states @@ -488,7 +488,7 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo } auto f_proc_measured = [this, &results, &thread_rand_states, &task, &space, thread_workloads]( int thread_id, int i) -> void { - Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; + 
Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; const Trace& trace = results[i]; if (Optional opt_sch = ReplayTrace(trace, task, space, rand_state, thread_workloads[thread_id])) { @@ -505,7 +505,7 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo std::atomic success_ct(0); auto f_proc_unmeasured = [this, &results, &thread_rand_states, &tot_fail_ct, &task, &space, &support, &success_ct, thread_workloads](int thread_id, int i) -> void { - Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; + Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; for (;;) { Trace support_trace = support[Sampler(rand_state).SampleInt(0, support.size())]->trace().value(); @@ -547,13 +547,13 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, const SearchTask& task, const SearchSpace& space, - Sampler::TRandomState* global_rand_state) { + Sampler::TRandState* global_rand_state) { // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(this->num_measures_per_iteration); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = + std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); std::vector> thread_trace_samplers(num_threads); @@ -563,7 +563,7 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, auto f_set_sampler = [this, num_threads, &thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used](const std::vector& scores) { for (int i = 0; i < num_threads; ++i) { - Sampler::TRandomState* rand_state = &thread_rand_states[i]; + Sampler::TRandState* rand_state = &thread_rand_states[i]; thread_trace_samplers[i] = Sampler(rand_state).MakeMultinomial(scores); thread_mutator_samplers[i] = MakeMutatorSampler(this->p_mutate, this->mutator_probs, rand_state); @@ -601,7 +601,7 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, &trace_used, &trace_used_mutex, &sch_curr, &sch_next, &task, &space, thread_workloads, this](int thread_id, int i) { // Prepare samplers - Sampler::TRandomState* rand_state = &thread_rand_states[thread_id]; + Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; const std::function& trace_sampler = thread_trace_samplers[thread_id]; const std::function()>& mutator_sampler = thread_mutator_samplers[thread_id]; @@ -678,7 +678,7 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, Array EvolutionaryNode::PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { int num_rands = this->num_measures_per_iteration * this->eps_greedy; int num_bests = this->num_measures_per_iteration - num_rands; std::vector rands = Sampler(rand_state).SampleWithoutReplacement(inits.size(), inits.size()); @@ -783,7 +783,7 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -801,7 +801,7 @@ struct Internal { */ static Array 
EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } @@ -819,7 +819,7 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index 42f34b9f05..fd5b7cb37c 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -36,7 +36,7 @@ Mutator::Mutator(String name, FApply apply) { /********** Mutator **********/ Optional MutatorNode::Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { return apply_(task, trace, rand_state); } @@ -79,7 +79,7 @@ class MutatorTileSize { } Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { // Find instruction `SamplePerfectTile` whose extent > 1 and n_splits > 1 std::vector candidates = FindCandidates(trace); if (candidates.empty()) { @@ -146,7 +146,7 @@ class MutatorTileSize { Mutator MutateTileSize() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorTileSize mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_tile_size", f_apply); } @@ -219,7 +219,7 @@ class MutatorComputeLocation { } Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { std::vector candidates = FindCandidates(trace, task->workload); if (candidates.empty()) { return NullOpt; @@ -233,7 +233,7 @@ class MutatorComputeLocation { Mutator MutateComputeLocation() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorComputeLocation mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_compute_location", f_apply); } @@ -312,7 +312,7 @@ class MutatorAutoUnroll { } Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state) { + Sampler::TRandState* rand_state) { std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; @@ -329,7 +329,7 @@ class MutatorAutoUnroll { Mutator MutateAutoUnroll() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorAutoUnroll mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_unroll_depth", f_apply); } @@ -434,7 +434,7 @@ class MutatorParallel { } Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state) const { + Sampler::TRandState* rand_state) const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); int max_extent = GetTargetNumCores(task->target, &warned_num_cores_missing) * max_jobs_per_core - 1; @@ -471,7 +471,7 @@ class MutatorParallel { Mutator 
MutateParallel(const int& max_jobs_per_core) { MutatorParallel mutator(max_jobs_per_core); auto f_apply = [mutator](SearchTask task, Trace trace, void* rand_state) -> Optional { - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_parallel", f_apply); } @@ -485,7 +485,7 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - Sampler::TRandomState rand_state = std::random_device()(); + Sampler::TRandState rand_state = std::random_device()(); if (seed.defined()) { Sampler(&rand_state).Seed(seed.value()); } diff --git a/src/meta_schedule/strategy/mutator.h b/src/meta_schedule/strategy/mutator.h index acc11cbbee..8c2ef293f7 100644 --- a/src/meta_schedule/strategy/mutator.h +++ b/src/meta_schedule/strategy/mutator.h @@ -48,7 +48,7 @@ class MutatorNode : public Object { * \return The new schedule after mutation, NullOpt if mutation fails */ Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandomState* rand_state); + Sampler::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); diff --git a/src/meta_schedule/strategy/replay.cc b/src/meta_schedule/strategy/replay.cc index c3966438f3..0ea85e5e0b 100644 --- a/src/meta_schedule/strategy/replay.cc +++ b/src/meta_schedule/strategy/replay.cc @@ -51,7 +51,7 @@ class ReplayNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler::TRandomState* rand_state, + const ProgramMeasurer& measurer, Sampler::TRandState* rand_state, int verbose) override; static constexpr const char* _type_key = "meta_schedule.Replay"; @@ -87,8 +87,8 @@ Replay::Replay(int batch_size, int num_trials) { Optional ReplayNode::Search(const SearchTask& task, const SearchSpace& space, const ProgramMeasurer& measurer, - Sampler::TRandomState* rand_state, int verbose) { - std::vector thread_rand_states; + Sampler::TRandState* rand_state, int verbose) { + std::vector thread_rand_states; std::vector thread_measure_inputs; thread_rand_states.reserve(this->batch_size); thread_measure_inputs.reserve(this->batch_size); @@ -97,7 +97,7 @@ Optional ReplayNode::Search(const SearchTask& task, const SearchSpace& thread_measure_inputs.emplace_back(nullptr); } auto worker = [&task, &space, &thread_rand_states, &thread_measure_inputs](int thread_id, int i) { - Sampler::TRandomState* rand_state = &thread_rand_states[i]; + Sampler::TRandState* rand_state = &thread_rand_states[i]; for (;;) { Schedule sch = space->SampleSchedule(task, rand_state); if (space->Postprocess(task, sch, rand_state)) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index a4dbf61bd0..2ca6910e85 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -42,7 +42,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief Source of randomness */ - Sampler::TRandomState rand_state_; + Sampler::TRandState rand_state_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. 
*/ @@ -66,12 +66,12 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } - Schedule Copy(Sampler::TRandomState new_seed = -1) const override; - void Seed(Sampler::TRandomState new_seed = -1) final { + Schedule Copy(Sampler::TRandState new_seed = -1) const override; + void Seed(Sampler::TRandState new_seed = -1) final { if (new_seed == -1) new_seed = std::random_device()(); Sampler(&this->rand_state_).Seed(new_seed); } - Sampler::TRandomState ForkSeed() final { return Sampler(&this->rand_state_).ForkSeed(); } + Sampler::TRandState ForkSeed() final { return Sampler(&this->rand_state_).ForkSeed(); } public: /******** Lookup random variables ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0b6cf1f909..408ef4e2eb 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -33,15 +33,15 @@ class Sampler; /******** Schedule: Sampling ********/ TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, - Sampler::TRandomState* rand_state, + Sampler::TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision); -TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandomState* rand_state, +TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, - Sampler::TRandomState* rand_state, + Sampler::TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6d403b1d30..26624c4030 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -22,7 +22,7 @@ namespace tvm { namespace tir { -std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandomState* rand_state, +std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision) { @@ -61,7 +61,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandom return result; } -int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandomState* rand_state, +int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { int i = -1; @@ -79,7 +79,7 @@ int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandomState* rand_s return candidates[i]; } -tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler::TRandomState* rand_state, +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler::TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision) { // Find all possible compute-at locations Array loop_srefs = tir::CollectComputeLocation(self, block_sref); diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc index ac76a13786..2ea9f23a89 100644 --- a/src/tir/schedule/sampler.cc +++ b/src/tir/schedule/sampler.cc @@ -126,15 +126,15 @@ struct PrimeTable { } }; -Sampler::TRandomState Sampler::ForkSeed() { +Sampler::TRandState Sampler::ForkSeed() { // In order for reproducibility, we computer the new seed using sampler's RNG's random state and a // different set of parameters. 
Note that both 32767 and 1999999973 are prime numbers. - Sampler::TRandomState ret = (this->rand_() * 32767) % 1999999973; + Sampler::TRandState ret = (this->rand_() * 32767) % 1999999973; return ret; } // We don't need to check the seed here because it's checked in LCE's seed function. -void Sampler::Seed(Sampler::TRandomState seed) { this->rand_.Seed(seed); } +void Sampler::Seed(Sampler::TRandState seed) { this->rand_.Seed(seed); } int Sampler::SampleInt(int min_inclusive, int max_exclusive) { if (min_inclusive + 1 == max_exclusive) { diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h index 60828c3237..c3dd5704fa 100644 --- a/src/tir/schedule/sampler.h +++ b/src/tir/schedule/sampler.h @@ -38,17 +38,17 @@ namespace tir { class Sampler { public: /*! Random state type for random number generator. */ - using TRandomState = support::LinearCongruentialEngine::result_type; + using TRandState = support::LinearCongruentialEngine::TRandState; /*! * \brief Return a random state value that can be used as seed for new samplers. * \return The random state value to be used as seed for new samplers. */ - TRandomState ForkSeed(); + TRandState ForkSeed(); /*! * \brief Re-seed the random number generator * \param seed The random state value given used to re-seed the RNG. */ - void Seed(TRandomState seed); + void Seed(TRandState seed); /*! * \brief Sample an integer in [min_inclusive, max_exclusive) * \param min_inclusive The left boundary, inclusive @@ -142,7 +142,7 @@ class Sampler { * \param random_state The given pointer to random state used to construct the RNG. * \note The random state is neither initialized not modified by this constructor. */ - explicit Sampler(TRandomState* random_state) : rand_(random_state) {} + explicit Sampler(TRandState* random_state) : rand_(random_state) {} private: /*! \brief The random number generator for sampling. */ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 29e4ac56ad..237662c576 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,7 +21,7 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, Sampler::TRandomState seed, int debug_mode, +Schedule Schedule::Traced(IRModule mod, Sampler::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); @@ -34,7 +34,7 @@ Schedule Schedule::Traced(IRModule mod, Sampler::TRandomState seed, int debug_mo return Schedule(std::move(n)); } -Schedule TracedScheduleNode::Copy(Sampler::TRandomState new_seed) const { +Schedule TracedScheduleNode::Copy(Sampler::TRandState new_seed) const { ObjectPtr n = make_object(); ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc new file mode 100644 index 0000000000..6435d5dc3c --- /dev/null +++ b/tests/cpp/random_engine_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +TEST(RandomEngine, Randomness) { + int64_t rand_state = 0; + + tvm::support::LinearCongruentialEngine rng(&rand_state); + rng.Seed(0x114514); + + bool covered[100]; + memset(covered, 0, sizeof(covered)); + for (int i = 0; i < 100000; i++) { + covered[rng() % 100] = true; + } + for (int i = 0; i < 100; i++) { + ICHECK(covered[i]); + } +} + +TEST(RandomEngine, Reproducibility) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x23456789); + rng_b.Seed(0x23456789); + + for (int i = 0; i < 100000; i++) { + ICHECK_EQ(rng_a(), rng_b()); + } +} + +TEST(RandomEngine, Serialization) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x56728); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); + + for (int i = 0; i < 123456; i++) rng_a(); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} \ No newline at end of file From bf27b94adc1b426d4de8fe83618510b6365fa1df Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 5 Aug 2021 16:16:29 -0700 Subject: [PATCH 10/23] Add new line. --- tests/cpp/random_engine_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc index 6435d5dc3c..10b8afa0ee 100644 --- a/tests/cpp/random_engine_test.cc +++ b/tests/cpp/random_engine_test.cc @@ -68,4 +68,4 @@ int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); -} \ No newline at end of file +} From 8b3484f17a884b051a8d5e8f23359b52fdf18264 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 5 Aug 2021 16:46:15 -0700 Subject: [PATCH 11/23] Change type and consistency. --- include/tvm/tir/schedule/schedule.h | 12 +++++++----- src/meta_schedule/cost_model/rand_cost_model.cc | 8 +++----- src/tir/schedule/concrete_schedule.cc | 6 +++--- src/tir/schedule/traced_schedule.cc | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9148a34760..195b3c351e 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -21,6 +21,8 @@ #include +#include "../../src/tir/schedule/sampler.h" + namespace tvm { namespace tir { @@ -113,12 +115,12 @@ class ScheduleNode : public runtime::Object { * 3) All the random variables are valid in the copy, pointing to the correpsonding sref * reconstructed */ - virtual Schedule Copy(int64_t seed = -1) const = 0; + virtual Schedule Copy(Sampler::TRandState seed = -1) const = 0; /*! 
* \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(int64_t seed = -1) = 0; + virtual void Seed(Sampler::TRandState seed = -1) = 0; /*! \brief Fork the random state */ virtual int64_t ForkSeed() = 0; @@ -502,11 +504,11 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Concrete(IRModule mod, Sampler::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Meta(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Meta(IRModule mod, Sampler::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Traced(IRModule mod, int64_t seed, int debug_mode, + TVM_DLL static Schedule Traced(IRModule mod, Sampler::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 75304c3888..b6b9958738 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -61,10 +61,9 @@ class RandCostModelNode : public CostModelNode { */ class RandCostModel : public CostModel { public: - RandCostModel() { data_ = make_object(); } - - explicit RandCostModel(int seed) { + explicit RandCostModel(int seed = -1) { ObjectPtr n = make_object(); + if (seed == -1) seed = std::random_device()(); Sampler(&n->rand_state).Seed(seed); data_ = std::move(n); } @@ -76,8 +75,7 @@ class RandCostModel : public CostModel { struct Internal { static RandCostModel New(Optional seed) { - return seed.defined() ? RandCostModel(seed.value()->value) - : RandCostModel(std::random_device()()); + return seed.defined() ? 
RandCostModel(seed.value()->value) : RandCostModel(); } }; diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index be66eaeb52..a053ffc4ff 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -23,7 +23,7 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int64_t seed, int debug_mode, +Schedule Schedule::Concrete(IRModule mod, Sampler::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); @@ -181,7 +181,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb ScheduleCopier::Copy(this, new_state, new_symbol_table); } -Schedule ConcreteScheduleNode::Copy(int64_t new_seed) const { +Schedule ConcreteScheduleNode::Copy(Sampler::TRandState new_seed) const { ObjectPtr n = make_object(); Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; @@ -668,7 +668,7 @@ void ConcreteScheduleNode::SoftwarePipeline(const LoopRV& loop_rv, int num_stage TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, int64_t seed, int debug_mode, + .set_body_typed([](IRModule mod, Sampler::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Concrete(mod, seed, debug_mode, static_cast(error_render_level)); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 237662c576..ce66e75919 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -453,7 +453,7 @@ void TracedScheduleNode::InlineArgument(int i, const String& func_name) { TVM_REGISTER_NODE_TYPE(TracedScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, int64_t seed, int debug_mode, + .set_body_typed([](IRModule mod, Sampler::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Traced(mod, seed, debug_mode, static_cast(error_render_level)); From 9872d232fb0550b78502f06b76c55539fb7bc654 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 5 Aug 2021 16:48:38 -0700 Subject: [PATCH 12/23] Change type. --- src/tir/schedule/traced_schedule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ed6582060e..beacce0ddc 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,7 +47,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: Optional trace() const final { return trace_; } - Schedule Copy(int64_t new_seed = -1) const final; + Schedule Copy(Sampler::TRandState new_seed = -1) const final; public: /******** Schedule: Sampling ********/ From ce582e70620f37902df258636a716014a72fb4b1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 10 Aug 2021 17:53:18 -0700 Subject: [PATCH 13/23] Dissolve sampler & fix LCE. 
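A minimal usage sketch of the dissolved API (a sketch only, not part of the diff; the seed and weights are arbitrary, and names are qualified as they appear elsewhere in this series):

    tir::TRandState rand_state = 0;                        // plain integer state owned by the caller
    tir::RandEngine(&rand_state).Seed(42);                 // replaces Sampler(&rand_state).Seed(42)
    int i = tir::SampleInt(&rand_state, 0, 8);             // replaces Sampler(&rand_state).SampleInt(0, 8)
    tir::TRandState child = tir::ForkSeed(&rand_state);    // decorrelated seed for a worker thread
    auto pick = tir::MakeMultinomial(&rand_state, {0.1, 0.6, 0.3});
    int choice = pick();                                   // index in [0, 3) drawn with the given weights

The LCE fix makes LinearCongruentialEngine::min()/max() static constexpr, which is what the UniformRandomBitGenerator requirement behind std::uniform_int_distribution and friends expects.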
--- include/tvm/support/random_engine.h | 6 +- include/tvm/tir/schedule/schedule.h | 16 +- src/meta_schedule/autotune.cc | 4 +- src/meta_schedule/autotune.h | 2 +- .../cost_model/rand_cost_model.cc | 6 +- src/meta_schedule/schedule.h | 3 - src/meta_schedule/search.cc | 20 +- src/meta_schedule/search.h | 11 +- src/meta_schedule/space/post_order_apply.cc | 19 +- src/meta_schedule/space/postproc.cc | 7 +- src/meta_schedule/space/postproc.h | 2 +- src/meta_schedule/space/schedule_fn.cc | 15 +- src/meta_schedule/strategy/evolutionary.cc | 66 +-- src/meta_schedule/strategy/mutator.cc | 48 +- src/meta_schedule/strategy/mutator.h | 3 +- src/meta_schedule/strategy/replay.cc | 12 +- src/tir/schedule/concrete_schedule.cc | 10 +- src/tir/schedule/concrete_schedule.h | 12 +- src/tir/schedule/primitive.h | 206 ++++++- src/tir/schedule/primitive/sampling.cc | 461 ++++++++++++++- src/tir/schedule/sampler.cc | 558 ------------------ src/tir/schedule/sampler.h | 160 ----- src/tir/schedule/traced_schedule.cc | 10 +- src/tir/schedule/traced_schedule.h | 2 +- tests/cpp/meta_schedule_test.cc | 8 +- 25 files changed, 786 insertions(+), 881 deletions(-) delete mode 100644 src/tir/schedule/sampler.cc delete mode 100644 src/tir/schedule/sampler.h diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index e73c1193f4..0889a383d6 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -19,7 +19,7 @@ /*! * \file random_engine.h - * \brief Random number generator, for Sampler and Sampling functions. + * \brief Random number generator, for Sampling functions. */ #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ @@ -63,13 +63,13 @@ class LinearCongruentialEngine { * \brief The minimum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type min() { return 0; } + static constexpr result_type min() { return 0; } /*! * \brief The maximum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type max() { return modulus - 1; } + static constexpr result_type max() { return modulus - 1; } /*! * \brief Operator to move the random state to the next and return the new random state. According diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 195b3c351e..02829f37fd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -19,13 +19,15 @@ #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ +#include #include -#include "../../src/tir/schedule/sampler.h" - namespace tvm { namespace tir { +using TRandState = support::LinearCongruentialEngine::TRandState; +using RandEngine = support::LinearCongruentialEngine; + /*! \brief The level of detailed error message rendering */ enum class ScheduleErrorRenderLevel : int32_t { /*! \brief Render a detailed error message */ @@ -115,12 +117,12 @@ class ScheduleNode : public runtime::Object { * 3) All the random variables are valid in the copy, pointing to the correpsonding sref * reconstructed */ - virtual Schedule Copy(Sampler::TRandState seed = -1) const = 0; + virtual Schedule Copy(tir::TRandState seed = -1) const = 0; /*! * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(Sampler::TRandState seed = -1) = 0; + virtual void Seed(tir::TRandState seed = -1) = 0; /*! 
\brief Fork the random state */ virtual int64_t ForkSeed() = 0; @@ -504,11 +506,11 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, Sampler::TRandState seed, int debug_mode, + TVM_DLL static Schedule Concrete(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Meta(IRModule mod, Sampler::TRandState seed, int debug_mode, + TVM_DLL static Schedule Meta(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); - TVM_DLL static Schedule Traced(IRModule mod, Sampler::TRandState seed, int debug_mode, + TVM_DLL static Schedule Traced(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index 7d6aec3e8a..1e028b15ce 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -25,9 +25,9 @@ namespace meta_schedule { void TuneContextNode::Init(Optional seed) { if (seed.defined() && seed.value() != -1) { - Sampler(&this->rand_state).Seed(seed.value()->value); + tir::RandEngine(&this->rand_state).Seed(seed.value()->value); } else { - Sampler(&this->rand_state).Seed(std::random_device()()); + tir::RandEngine(&this->rand_state).Seed(std::random_device()()); } if (task.defined()) { task.value()->Init(this); diff --git a/src/meta_schedule/autotune.h b/src/meta_schedule/autotune.h index f56973fa42..e61027a924 100644 --- a/src/meta_schedule/autotune.h +++ b/src/meta_schedule/autotune.h @@ -44,7 +44,7 @@ class TuneContextNode : public runtime::Object { Array measure_callbacks; int num_threads; - Sampler::TRandState rand_state; + tir::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index b6b9958738..0afc722f7d 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -28,7 +28,7 @@ namespace meta_schedule { class RandCostModelNode : public CostModelNode { public: /*! 
\brief A random state for sampler to generate random numbers */ - Sampler::TRandState rand_state; + tir::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { // sampler is not visited @@ -48,7 +48,7 @@ class RandCostModelNode : public CostModelNode { * \return The predicted scores for all states */ std::vector Predict(const SearchTask& task, const Array& states) override { - return Sampler(&rand_state).SampleUniform(states.size(), 0.0, 1.0); + return tir::SampleUniform(&rand_state, states.size(), 0.0, 1.0); } static constexpr const char* _type_key = "meta_schedule.RandCostModel"; @@ -64,7 +64,7 @@ class RandCostModel : public CostModel { explicit RandCostModel(int seed = -1) { ObjectPtr n = make_object(); if (seed == -1) seed = std::random_device()(); - Sampler(&n->rand_state).Seed(seed); + tir::RandEngine(&n->rand_state).Seed(seed); data_ = std::move(n); } diff --git a/src/meta_schedule/schedule.h b/src/meta_schedule/schedule.h index 96844ea325..cf94400ca6 100644 --- a/src/meta_schedule/schedule.h +++ b/src/meta_schedule/schedule.h @@ -22,14 +22,11 @@ #include #include -#include "../tir/schedule/sampler.h" - namespace tvm { namespace meta_schedule { using ScheduleNode = tir::TraceNode; using Schedule = tir::Schedule; -using Sampler = tir::Sampler; using BlockRV = tir::BlockRV; using BlockRVNode = tir::BlockRVNode; using LoopRV = tir::LoopRV; diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index 02fe13d2f0..a872e2dd75 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,9 +58,9 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } if (verbose) { @@ -108,9 +108,9 @@ struct Internal { */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return space->Postprocess(task, sch, &rand_state); } @@ -123,9 +123,9 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return space->SampleSchedule(task, &rand_state); } @@ -139,9 +139,9 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return space->GetSupport(task, &rand_state); } @@ -157,9 +157,9 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - Sampler::TRandState rand_state = std::random_device()(); + 
tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return strategy->Search(task, space, measurer, &rand_state, verbose); } diff --git a/src/meta_schedule/search.h b/src/meta_schedule/search.h index 7b4cfa6859..3bc6571d8f 100644 --- a/src/meta_schedule/search.h +++ b/src/meta_schedule/search.h @@ -21,6 +21,7 @@ #include +#include "../tir/schedule/primitive.h" #include "./schedule.h" namespace tvm { @@ -104,20 +105,20 @@ class SearchSpaceNode : public runtime::Object { * \param rand_state The sampler's random state */ virtual bool Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) = 0; + tir::TRandState* rand_state) = 0; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - virtual Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) = 0; + virtual Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) = 0; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return The support of the search space. Any point from the search space should along to one of * the traces returned */ - virtual Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) = 0; + virtual Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) = 0; static constexpr const char* _type_key = "meta_schedule.SearchSpace"; TVM_DECLARE_BASE_OBJECT_INFO(SearchSpaceNode, Object); @@ -157,8 +158,8 @@ class SearchStrategyNode : public Object { * \return The best schedule found, NullOpt if no valid schedule is found */ virtual Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, - Sampler::TRandState* rand_state, int verbose) = 0; + const ProgramMeasurer& measurer, tir::TRandState* rand_state, + int verbose) = 0; /*! \brief Explore the search space */ virtual void Search() { LOG(FATAL) << "NotImplemented"; } diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 9011b012ca..2ea4af08e9 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -52,20 +52,20 @@ class PostOrderApplyNode : public SearchSpaceNode { * \param rand_state The sampler's random state */ bool Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) override; + tir::TRandState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) override; + Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) override; /*! 
* \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa PostOrderApplyNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) override; + Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SearchSpaceNode); @@ -98,7 +98,7 @@ PostOrderApply::PostOrderApply(Array stages, Array postpro /********** Sampling **********/ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) { + tir::TRandState* rand_state) { sch->EnterPostProc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { @@ -108,11 +108,10 @@ bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch return true; } -Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, - Sampler::TRandState* rand_state) { +Schedule PostOrderApplyNode::SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) { Array support = GetSupport(task, rand_state); ICHECK(!support.empty()) << "ValueError: Found null support"; - int i = Sampler(rand_state).SampleInt(0, support.size()); + int i = tir::SampleInt(rand_state, 0, support.size()); return support[i]; } @@ -149,12 +148,12 @@ class BlockCollector : public tir::StmtVisitor { }; Array PostOrderApplyNode::GetSupport(const SearchTask& task, - Sampler::TRandState* rand_state) { + tir::TRandState* rand_state) { using ScheduleAndUnvisitedBlocks = std::pair>; Array curr{ Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/Sampler(rand_state).ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail)}; for (const SearchRule& rule : stages) { @@ -204,7 +203,7 @@ Array PostOrderApplyNode::GetSupport(const SearchTask& task, Trace trace = sch->trace().value()->Simplified(/*remove_postproc=*/true); Schedule new_sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/Sampler(rand_state).ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(new_sch, /*remove_postproc=*/true); diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index 61ebb75ec0..5c32cf1da2 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -38,8 +38,7 @@ Postproc::Postproc(String name, FProc proc) { /********** Postproc **********/ -bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) { +bool PostprocNode::Apply(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) { return proc_(task, sch, rand_state); } @@ -1119,9 +1118,9 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return self->Apply(task, sch, &rand_state); } diff --git a/src/meta_schedule/space/postproc.h b/src/meta_schedule/space/postproc.h index 
3b388d4e7b..d9673fa48c 100644 --- a/src/meta_schedule/space/postproc.h +++ b/src/meta_schedule/space/postproc.h @@ -47,7 +47,7 @@ class PostprocNode : public Object { * \param rand_state The sampler's random state * \return If the post-processing succeeds */ - bool Apply(const SearchTask& task, const Schedule& sch, Sampler::TRandState* rand_state); + bool Apply(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Postproc"; TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index 902529e200..de179a9804 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -50,20 +50,20 @@ class ScheduleFnNode : public SearchSpaceNode { * \param rand_state The sampler's random state */ bool Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) override; + tir::TRandState* rand_state) override; /*! * \brief Sample a schedule out of the search space * \param task The search task to be sampled from * \return The schedule sampled */ - Schedule SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) override; + Schedule SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) override; /*! * \brief Get support of the search space * \param task The search task to be sampled from * \return An array with a single element returned from SampleSchedule * \sa ScheduleFnNode::SampleSchedule */ - Array GetSupport(const SearchTask& task, Sampler::TRandState* rand_state) override; + Array GetSupport(const SearchTask& task, tir::TRandState* rand_state) override; static constexpr const char* _type_key = "meta_schedule.ScheduleFn"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SearchSpaceNode); @@ -96,7 +96,7 @@ ScheduleFn::ScheduleFn(PackedFunc sch_fn, Array postprocs) { /********** Sampling **********/ bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, - Sampler::TRandState* rand_state) { + tir::TRandState* rand_state) { sch->EnterPostProc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { @@ -106,17 +106,16 @@ bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, return true; } -Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, Sampler::TRandState* rand_state) { +Schedule ScheduleFnNode::SampleSchedule(const SearchTask& task, tir::TRandState* rand_state) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), task->workload}}), - /*seed=*/Sampler(rand_state).ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); this->sch_fn_(sch); return sch; } -Array ScheduleFnNode::GetSupport(const SearchTask& task, - Sampler::TRandState* rand_state) { +Array ScheduleFnNode::GetSupport(const SearchTask& task, tir::TRandState* rand_state) { return {SampleSchedule(task, rand_state)}; } diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index 5bfa415d23..eafceb6695 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -139,7 +139,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& 
measurer, Sampler::TRandState* rand_state, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) override; /********** Stages in evolutionary search **********/ @@ -155,7 +155,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return The generated samples, all of which are not post-processed */ Array SampleInitPopulation(const Array& support, const SearchTask& task, - const SearchSpace& space, Sampler::TRandState* rand_state); + const SearchSpace& space, tir::TRandState* rand_state); /*! * \brief Perform evolutionary search using genetic algorithm with the cost model @@ -166,7 +166,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return An array of schedules, the sampling result */ Array EvolveWithCostModel(const Array& inits, const SearchTask& task, - const SearchSpace& space, Sampler::TRandState* rand_state); + const SearchSpace& space, tir::TRandState* rand_state); /*! * \brief Pick a batch of samples for measurement with epsilon greedy @@ -179,7 +179,7 @@ class EvolutionaryNode : public SearchStrategyNode { */ Array PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler::TRandState* rand_state); + tir::TRandState* rand_state); /*! * \brief Make measurements and update the cost model @@ -205,11 +205,11 @@ class EvolutionaryNode : public SearchStrategyNode { * \param rand_state The sampler's random state * \return A list of random states, the result of forking */ - static std::vector ForkSamplers(int n, Sampler::TRandState* rand_state) { - std::vector result; + static std::vector ForkSamplers(int n, tir::TRandState* rand_state) { + std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) { - result.emplace_back(Sampler(rand_state).ForkSeed()); + result.emplace_back(tir::ForkSeed(rand_state)); } return result; } @@ -228,10 +228,10 @@ class EvolutionaryNode : public SearchStrategyNode { * \brief Replay the trace and do postprocessing */ static Optional ReplayTrace(const Trace& trace, const SearchTask& task, - const SearchSpace& space, Sampler::TRandState* rand_state, + const SearchSpace& space, tir::TRandState* rand_state, const tir::PrimFunc& workload) { Schedule sch = Schedule::Traced(/*mod=*/IRModule({{GlobalVar("main"), workload}}), - /*seed=*/Sampler(rand_state).ForkSeed(), + /*seed=*/tir::ForkSeed(rand_state), /*debug_mode=*/false, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); trace->ApplyToSchedule(sch, /*remove_postproc=*/true); @@ -247,8 +247,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \return The sampler created */ static std::function()> MakeMutatorSampler( - double p_mutate, const Map& mutator_probs, - Sampler::TRandState* rand_state) { + double p_mutate, const Map& mutator_probs, tir::TRandState* rand_state) { CHECK(0.0 <= p_mutate && p_mutate <= 1.0) // << "ValueError: Probability should be within [0, 1], " << "but get `p_mutate = " << p_mutate << '\''; @@ -277,7 +276,7 @@ class EvolutionaryNode : public SearchStrategyNode { masses[i] /= total_mass_mutator; } } - auto idx_sampler = Sampler(rand_state).MakeMultinomial(masses); + auto idx_sampler = tir::MakeMultinomial(rand_state, masses); return [idx_sampler = std::move(idx_sampler), mutators = std::move(mutators)]() -> Optional { int i = idx_sampler(); @@ -424,7 +423,7 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i CHECK_LE(num_measures_per_iteration, population) << "ValueError: requires `num_measures_per_iteration <= population`"; { - 
Sampler::TRandState rand_state = 42; + tir::TRandState rand_state = 42; EvolutionaryNode::MakeMutatorSampler(p_mutate, mutator_probs, &rand_state); } ObjectPtr n = make_object(); @@ -445,7 +444,7 @@ Evolutionary::Evolutionary(int total_measures, int num_measures_per_iteration, i Optional EvolutionaryNode::Search(const SearchTask& task, const SearchSpace& space, const ProgramMeasurer& measurer, - Sampler::TRandState* rand_state, int verbose) { + tir::TRandState* rand_state, int verbose) { Array support = space->GetSupport(task, rand_state); int iter = 1; for (int num_measured = 0; num_measured < this->total_measures; ++iter) { @@ -472,14 +471,13 @@ Optional EvolutionaryNode::Search(const SearchTask& task, const Search Array EvolutionaryNode::SampleInitPopulation(const Array& support, const SearchTask& task, const SearchSpace& space, - Sampler::TRandState* global_rand_state) { + tir::TRandState* global_rand_state) { trace_cache_.clear(); std::vector results; results.reserve(this->population); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = - ForkSamplers(num_threads, global_rand_state); + std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); // Pick measured states int num_measured = this->population * this->init_measured_ratio; @@ -488,7 +486,7 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo } auto f_proc_measured = [this, &results, &thread_rand_states, &task, &space, thread_workloads]( int thread_id, int i) -> void { - Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; + tir::TRandState* rand_state = &thread_rand_states[thread_id]; const Trace& trace = results[i]; if (Optional opt_sch = ReplayTrace(trace, task, space, rand_state, thread_workloads[thread_id])) { @@ -505,10 +503,9 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo std::atomic success_ct(0); auto f_proc_unmeasured = [this, &results, &thread_rand_states, &tot_fail_ct, &task, &space, &support, &success_ct, thread_workloads](int thread_id, int i) -> void { - Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; + tir::TRandState* rand_state = &thread_rand_states[thread_id]; for (;;) { - Trace support_trace = - support[Sampler(rand_state).SampleInt(0, support.size())]->trace().value(); + Trace support_trace = support[tir::SampleInt(rand_state, 0, support.size())]->trace().value(); Map decisions; try { if (Optional opt_sch = @@ -547,14 +544,13 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, const SearchTask& task, const SearchSpace& space, - Sampler::TRandState* global_rand_state) { + tir::TRandState* global_rand_state) { // The heap to record best schedule, we do not consider schedules that are already measured // Also we use `in_heap` to make sure items in the heap are de-duplicated SizedHeap heap(this->num_measures_per_iteration); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = - ForkSamplers(num_threads, global_rand_state); + std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); std::vector> thread_trace_samplers(num_threads); std::vector()>> thread_mutator_samplers(num_threads); @@ -563,8 +559,8 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, auto 
f_set_sampler = [this, num_threads, &thread_rand_states, &thread_trace_samplers, &thread_mutator_samplers, &trace_used](const std::vector& scores) { for (int i = 0; i < num_threads; ++i) { - Sampler::TRandState* rand_state = &thread_rand_states[i]; - thread_trace_samplers[i] = Sampler(rand_state).MakeMultinomial(scores); + tir::TRandState* rand_state = &thread_rand_states[i]; + thread_trace_samplers[i] = tir::MakeMultinomial(rand_state, scores); thread_mutator_samplers[i] = MakeMutatorSampler(this->p_mutate, this->mutator_probs, rand_state); } @@ -601,7 +597,7 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, &trace_used, &trace_used_mutex, &sch_curr, &sch_next, &task, &space, thread_workloads, this](int thread_id, int i) { // Prepare samplers - Sampler::TRandState* rand_state = &thread_rand_states[thread_id]; + tir::TRandState* rand_state = &thread_rand_states[thread_id]; const std::function& trace_sampler = thread_trace_samplers[thread_id]; const std::function()>& mutator_sampler = thread_mutator_samplers[thread_id]; @@ -678,10 +674,10 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, Array EvolutionaryNode::PickWithEpsGreedy(const Array& inits, const Array& bests, const SearchTask& task, const SearchSpace& space, - Sampler::TRandState* rand_state) { + tir::TRandState* rand_state) { int num_rands = this->num_measures_per_iteration * this->eps_greedy; int num_bests = this->num_measures_per_iteration - num_rands; - std::vector rands = Sampler(rand_state).SampleWithoutReplacement(inits.size(), inits.size()); + std::vector rands = tir::SampleWithoutReplacement(rand_state, inits.size(), inits.size()); Array results; results.reserve(this->num_measures_per_iteration); for (int i = 0, i_bests = 0, i_rands = 0; i < this->num_measures_per_iteration; ++i) { @@ -783,9 +779,9 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return self->SampleInitPopulation(support, task, space, &rand_state); } @@ -801,9 +797,9 @@ struct Internal { */ static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return self->EvolveWithCostModel(inits, task, space, &rand_state); } @@ -819,9 +815,9 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return self->PickWithEpsGreedy(inits, bests, task, space, &rand_state); } diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index fd5b7cb37c..c6d39b452e 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -36,7 +36,7 @@ Mutator::Mutator(String name, FApply apply) { /********** Mutator **********/ Optional MutatorNode::Apply(const 
SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state) { + tir::TRandState* rand_state) { return apply_(task, trace, rand_state); } @@ -78,18 +78,17 @@ class MutatorTileSize { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { // Find instruction `SamplePerfectTile` whose extent > 1 and n_splits > 1 std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Instruction& inst = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; + const Instruction& inst = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; std::vector tiles = CastDecision(trace->decisions.at(inst)); int n_splits = tiles.size(); // Choose two loops - int x = Sampler(rand_state).SampleInt(0, n_splits); + int x = tir::SampleInt(rand_state, 0, n_splits); int y; if (tiles[x] == 1) { // need to guarantee that tiles[x] * tiles[y] > 1 @@ -100,10 +99,10 @@ class MutatorTileSize { idx.push_back(i); } } - y = idx[Sampler(rand_state).SampleInt(0, idx.size())]; + y = idx[tir::SampleInt(rand_state, 0, idx.size())]; } else { // sample without replacement - y = Sampler(rand_state).SampleInt(0, n_splits - 1); + y = tir::SampleInt(rand_state, 0, n_splits - 1); if (y >= x) { ++y; } @@ -117,7 +116,7 @@ class MutatorTileSize { int len_x, len_y; if (y != n_splits - 1) { do { - std::vector result = Sampler(rand_state).SamplePerfectTile(2, tiles[x] * tiles[y]); + std::vector result = tir::SamplePerfectTile(rand_state, 2, tiles[x] * tiles[y]); len_x = result[0]; len_y = result[1]; } while (len_y == tiles[y]); @@ -134,7 +133,7 @@ class MutatorTileSize { if (len_y_space.empty()) { return NullOpt; } - len_y = len_y_space[Sampler(rand_state).SampleInt(0, len_y_space.size())]; + len_y = len_y_space[tir::SampleInt(rand_state, 0, len_y_space.size())]; len_x = prod / len_y; } tiles[x] = len_x; @@ -146,7 +145,7 @@ class MutatorTileSize { Mutator MutateTileSize() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorTileSize mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_tile_size", f_apply); } @@ -218,14 +217,13 @@ class MutatorComputeLocation { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { std::vector candidates = FindCandidates(trace, task->workload); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; - int loc = candidate.locs[Sampler(rand_state).SampleInt(0, candidate.locs.size())]; + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); } }; @@ -233,7 +231,7 @@ class MutatorComputeLocation { Mutator MutateComputeLocation() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorComputeLocation mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_compute_location", f_apply); } @@ 
-311,14 +309,13 @@ class MutatorAutoUnroll { return candidates; } - Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state) { + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state) { std::vector candidates = FindCandidates(trace); if (candidates.empty()) { return NullOpt; } - const Candidate& candidate = candidates[Sampler(rand_state).SampleInt(0, candidates.size())]; - int result = Sampler(rand_state).MakeMultinomial(candidate.weights)(); + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int result = tir::MakeMultinomial(rand_state, candidate.weights)(); if (result >= candidate.ori_decision) { result++; } @@ -329,7 +326,7 @@ class MutatorAutoUnroll { Mutator MutateAutoUnroll() { auto f_apply = [](SearchTask task, Trace trace, void* rand_state) -> Optional { MutatorAutoUnroll mutator; - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_unroll_depth", f_apply); } @@ -434,7 +431,7 @@ class MutatorParallel { } Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state) const { + tir::TRandState* rand_state) const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); int max_extent = GetTargetNumCores(task->target, &warned_num_cores_missing) * max_jobs_per_core - 1; @@ -444,8 +441,7 @@ class MutatorParallel { } const BlockRV& block = Downcast(candidate.inst->inputs[0]); const std::vector& extent_candidates = candidate.extent_candidates; - int parallel_size = - extent_candidates[Sampler(rand_state).SampleInt(0, extent_candidates.size())]; + int parallel_size = extent_candidates[tir::SampleInt(rand_state, 0, extent_candidates.size())]; std::vector new_insts; for (const Instruction& inst : trace->insts) { @@ -471,7 +467,7 @@ class MutatorParallel { Mutator MutateParallel(const int& max_jobs_per_core) { MutatorParallel mutator(max_jobs_per_core); auto f_apply = [mutator](SearchTask task, Trace trace, void* rand_state) -> Optional { - return mutator.Apply(task, trace, static_cast(rand_state)); + return mutator.Apply(task, trace, static_cast(rand_state)); }; return Mutator("mutate_parallel", f_apply); } @@ -485,9 +481,9 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - Sampler::TRandState rand_state = std::random_device()(); + tir::TRandState rand_state = std::random_device()(); if (seed.defined()) { - Sampler(&rand_state).Seed(seed.value()); + tir::RandEngine(&rand_state).Seed(seed.value()); } return mutator->Apply(task, trace, &rand_state); } diff --git a/src/meta_schedule/strategy/mutator.h b/src/meta_schedule/strategy/mutator.h index 8c2ef293f7..dae5cd30b3 100644 --- a/src/meta_schedule/strategy/mutator.h +++ b/src/meta_schedule/strategy/mutator.h @@ -47,8 +47,7 @@ class MutatorNode : public Object { * \param rand_state The sampler's random state * \return The new schedule after mutation, NullOpt if mutation fails */ - Optional Apply(const SearchTask& task, const Trace& trace, - Sampler::TRandState* rand_state); + Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state); static constexpr const char* _type_key = "meta_schedule.Mutator"; TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); diff --git a/src/meta_schedule/strategy/replay.cc b/src/meta_schedule/strategy/replay.cc index 0ea85e5e0b..0a4fb9a691 100644 --- 
a/src/meta_schedule/strategy/replay.cc +++ b/src/meta_schedule/strategy/replay.cc @@ -51,7 +51,7 @@ class ReplayNode : public SearchStrategyNode { * \return The best schedule found, NullOpt if no valid schedule is found */ Optional Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, Sampler::TRandState* rand_state, + const ProgramMeasurer& measurer, tir::TRandState* rand_state, int verbose) override; static constexpr const char* _type_key = "meta_schedule.Replay"; @@ -86,18 +86,18 @@ Replay::Replay(int batch_size, int num_trials) { /********** Search **********/ Optional ReplayNode::Search(const SearchTask& task, const SearchSpace& space, - const ProgramMeasurer& measurer, - Sampler::TRandState* rand_state, int verbose) { - std::vector thread_rand_states; + const ProgramMeasurer& measurer, tir::TRandState* rand_state, + int verbose) { + std::vector thread_rand_states; std::vector thread_measure_inputs; thread_rand_states.reserve(this->batch_size); thread_measure_inputs.reserve(this->batch_size); for (int i = 0; i < batch_size; ++i) { - thread_rand_states.emplace_back(Sampler(rand_state).ForkSeed()); + thread_rand_states.emplace_back(tir::ForkSeed(rand_state)); thread_measure_inputs.emplace_back(nullptr); } auto worker = [&task, &space, &thread_rand_states, &thread_measure_inputs](int thread_id, int i) { - Sampler::TRandState* rand_state = &thread_rand_states[i]; + tir::TRandState* rand_state = &thread_rand_states[i]; for (;;) { Schedule sch = space->SampleSchedule(task, rand_state); if (space->Postprocess(task, sch, rand_state)) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index a053ffc4ff..56aa63d4c2 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -23,13 +23,13 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, Sampler::TRandState seed, int debug_mode, +Schedule Schedule::Concrete(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; if (seed == -1) seed = std::random_device()(); - Sampler(&n->rand_state_).Seed(seed); + tir::RandEngine(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -181,13 +181,13 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb ScheduleCopier::Copy(this, new_state, new_symbol_table); } -Schedule ConcreteScheduleNode::Copy(Sampler::TRandState new_seed) const { +Schedule ConcreteScheduleNode::Copy(tir::TRandState new_seed) const { ObjectPtr n = make_object(); Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); if (new_seed == -1) new_seed = std::random_device()(); - Sampler(&n->rand_state_).Seed(new_seed); + tir::RandEngine(&n->rand_state_).Seed(new_seed); return Schedule(std::move(n)); } @@ -668,7 +668,7 @@ void ConcreteScheduleNode::SoftwarePipeline(const LoopRV& loop_rv, int num_stage TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, Sampler::TRandState seed, int debug_mode, + .set_body_typed([](IRModule mod, tir::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Concrete(mod, seed, debug_mode, static_cast(error_render_level)); diff --git 
a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 2ca6910e85..e274b633b1 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -23,7 +23,7 @@ #include #include -#include "./sampler.h" +#include "./primitive.h" #include "./utils.h" namespace tvm { @@ -42,7 +42,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief Source of randomness */ - Sampler::TRandState rand_state_; + tir::TRandState rand_state_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ @@ -66,12 +66,12 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } - Schedule Copy(Sampler::TRandState new_seed = -1) const override; - void Seed(Sampler::TRandState new_seed = -1) final { + Schedule Copy(tir::TRandState new_seed = -1) const override; + void Seed(tir::TRandState new_seed = -1) final { if (new_seed == -1) new_seed = std::random_device()(); - Sampler(&this->rand_state_).Seed(new_seed); + RandEngine(&this->rand_state_).Seed(new_seed); } - Sampler::TRandState ForkSeed() final { return Sampler(&this->rand_state_).ForkSeed(); } + tir::TRandState ForkSeed() final { return tir::ForkSeed(&this->rand_state_); } public: /******** Lookup random variables ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 408ef4e2eb..561812e81e 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -19,29 +19,219 @@ #ifndef TVM_TIR_SCHEDULE_PRIMITIVES_PRIMITIVES_H_ #define TVM_TIR_SCHEDULE_PRIMITIVES_PRIMITIVES_H_ +#include +#include #include +#include #include -#include "sampler.h" - namespace tvm { namespace tir { -class Sampler; +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! 
\brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int factor = primes[i]; + int y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! + * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int n) const { + std::vector> result; + result.reserve(16); + int i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int j; n > 1;) { + int i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; /******** Schedule: Sampling ********/ -TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, - Sampler::TRandState* rand_state, +/*! \brief Return a seed that can be used to create a new sampler */ +TRandState ForkSeed(TRandState* rand_state); +/*! + * \brief Sample an integer in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The integer sampled + */ +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive); +/*! + * \brief Sample n integers in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The list of integers sampled + */ +std::vector SampleInts(TRandState* rand_state, int n, int min_inclusive, int max_exclusive); +/*! + * \brief Random shuffle from the begin iterator to the end. + * \param begin_it The begin iterator + * \param end_it The end iterator + */ +template +void SampleShuffle(TRandState* rand_state, RandomAccessIterator begin_it, + RandomAccessIterator end_it); +/*! 
+ * \brief Sample n tiling factors of the specific extent + * \param n The number of parts the loop is split + * \param extent Length of the loop + * \param candidates The possible tiling factors + * \return A list of length n, the tiling factors sampled + */ +std::vector SampleTileFactor(TRandState* rand_state, int n, int extent, + const std::vector& candidates); +/*! + * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent); +/*! + * \brief Sample perfect tiling factor of the specific extent + * \param n_splits The number of parts the loop is split + * \param extent Length of the loop + * \param max_innermost_factor A small number indicating the max length of the innermost loop + * \return A list of length n_splits, the tiling factors sampled, the product of which strictly + * equals to extent + */ +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent, + int max_innermost_factor); +/*! + * \brief Sample shape-generic tiling factors that are determined by the hardware constraints. + * \param n_splits The number of parts the loops are split + * \param max_extents Maximum length of the loops + * \param is_spatial Whether each loop is a spatial axis or not + * \param target Hardware target + * \param max_innermost_factor A small number indicating the max length of the innermost loop + * \return A list of list of length n_splits, the tiling factors sampled, all satisfying the + * maximum extents and the hardware constraints + */ +std::vector> SampleShapeGenericTiles(TRandState* rand_state, + const std::vector& n_splits, + const std::vector& max_extents, + const Target& target, + int max_innermost_factor); +/*! + * \brief Sample n floats uniformly in [min, max) + * \param min The left boundary + * \param max The right boundary + * \return The list of floats sampled + */ +std::vector SampleUniform(TRandState* rand_state, int n, double min, double max); +/*! + * \brief Sample from a Bernoulli distribution + * \param p Parameter in the Bernoulli distribution + * \return return true with probability p, and false with probability (1 - p) + */ +bool SampleBernoulli(TRandState* rand_state, double p); +/*! + * \brief Create a multinomial sampler based on the specific weights + * \param weights The weights, event probabilities + * \return The multinomial sampler + */ +std::function MakeMultinomial(TRandState* rand_state, const std::vector& weights); +/*! 
+ * \brief Classic sampling without replacement + * \param n The population size + * \param k The number of samples to be drawn from the population + * \return A list of indices, samples drawn, unsorted and index starting from 0 + */ +std::vector SampleWithoutReplacement(TRandState* rand_state, int n, int k); + +TVM_DLL std::vector SamplePerfectTile(tir::ScheduleState self, tir::TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision); -TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandState* rand_state, +TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, tir::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); -TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, - Sampler::TRandState* rand_state, +TVM_DLL tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, tir::TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 26624c4030..7057527ef9 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -16,13 +16,459 @@ * specific language governing permissions and limitations * under the License. */ -#include "../sampler.h" +#include +#include + #include "../utils.h" namespace tvm { namespace tir { -std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandState* rand_state, +TRandState ForkSeed(TRandState* rand_state) { + // In order for reproducibility, we computer the new seed using sampler's RNG's random state and a + // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. + TRandState ret = (RandEngine(rand_state)() * 32767) % 1999999973; + return ret; +} + +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { + RandEngine rand_(rand_state); + + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + +std::vector SampleInts(TRandState* rand_state, int n, int min_inclusive, int max_exclusive) { + RandEngine rand_(rand_state); + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; +} + +template +void SampleShuffle(TRandState* rand_state, RandomAccessIterator begin_it, + RandomAccessIterator end_it) { + RandEngine rand_(rand_state); + std::shuffle(begin_it, end_it, rand_); +} + +std::vector SampleUniform(TRandState* rand_state, int n, double min, double max) { + RandEngine rand_(rand_state); + std::uniform_real_distribution dist(min, max); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(dist(rand_)); + } + return result; +} + +bool SampleBernoulli(TRandState* rand_state, double p) { + RandEngine rand_(rand_state); + std::bernoulli_distribution dist(p); + return dist(rand_); +} + +std::function MakeMultinomial(TRandState* rand_state, const std::vector& weights) { + RandEngine rand_(rand_state); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { + RandEngine rand_(rand_state); + double p = 
dist(rand_); + int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; + return sampler; +} + +std::vector SampleWithoutReplacement(TRandState* rand_state, int n, int k) { + if (k == 1) { + return {SampleInt(rand_state, 0, n)}; + } + if (k == 2) { + int result0 = SampleInt(rand_state, 0, n); + int result1 = SampleInt(rand_state, 0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int i = 0; i < n; ++i) { + order[i] = i; + } + for (int i = 0; i < k; ++i) { + int j = SampleInt(rand_state, i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; +} + +std::vector SampleTileFactor(TRandState* rand_state, int n, int extent, + const std::vector& candidates) { + RandEngine rand_(rand_state); + constexpr int kMaxTrials = 100; + std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); + std::vector sample(n, -1); + for (int trial = 0; trial < kMaxTrials; ++trial) { + int64_t product = 1; + for (int i = 1; i < n; ++i) { + int value = candidates[dist(rand_)]; + product *= value; + if (product > extent) { + break; + } + sample[i] = value; + } + if (product <= extent) { + sample[0] = (extent + product - 1) / product; + return sample; + } + } + sample[0] = extent; + for (int i = 1; i < n; ++i) { + sample[i] = 1; + } + return sample; +} + +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // (primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int result0 = 1; + int result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + int x1 = SampleInt(rand_state, 0, p + 1); + int x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(rand_state, 0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int p = ij.second; + if (p == 1) { + result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. 
We have to sample uniformly from the solution of: + // x_1 + x_2 + ... + x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int i = 0, last = -1; i < n_splits; ++i) { + int x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; +} + +std::vector SamplePerfectTile(TRandState* rand_state, int n_splits, int extent, + int max_innermost_factor) { + if (max_innermost_factor == -1) { + return SamplePerfectTile(rand_state, n_splits, extent); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, n_splits - 1, extent / innermost); + result.push_back(innermost); + return result; +} + +static inline int ExtractInt(const Target& target, const char* name) { + if (Optional v = target->GetAttr(name)) { + return v.value()->value; + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +static inline bool IsCudaTarget(const Target& target) { + if (Optional v = target->GetAttr("kind")) { + return v.value() == "cuda"; + } + return false; +} + +std::vector> SampleShapeGenericTiles(TRandState* rand_state, + const std::vector& n_splits, + const std::vector& max_extents, + const Target& target, + int max_innermost_factor) { + std::vector> ret_split_factors; + + if (IsCudaTarget(target)) { + // The following factorization scheme is built under the assumption that: (1) The target is + // CUDA, and (2) The tiling structure is SSSRRSRS. 
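+    // Here each "S" denotes a spatial loop level and each "R" a reduction loop level:
+    // every spatial loop is expected to be split into 4 factors and every reduction loop
+    // into 2 factors, which is why the checks below only accept n_splits values of 4 or 2.
+    // For a spatial loop, factor[1] receives its share of the sampled threads-per-block,
+    // factor[0] the virtual-thread extent, and factor[3] / factor[2] the inner tile sizes;
+    // the two reduction factors are sampled afterwards under the shared-memory budget.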
+ + // extract all the hardware parameters + const struct HardwareConstraints { + int max_shared_memory_per_block; + int max_local_memory_per_block; + int max_threads_per_block; + int max_innermost_factor; + int max_vthread; + } constraints = {ExtractInt(target, "shared_memory_per_block"), + ExtractInt(target, "registers_per_block"), + ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; + + for (const int n_split : n_splits) { + ret_split_factors.push_back(std::vector(n_split, 1)); + } + + // sample the number of threads per block + const int warp_size = ExtractInt(target, "warp_size"); + int num_threads_per_block = + SampleInt(rand_state, 1, constraints.max_threads_per_block / warp_size) * warp_size; + // find all the possible factors of the number of threads per block + int num_spatial_axes = 0; + size_t last_spatial_iter_id = -1; + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + CHECK(n_splits[iter_id] == 4 || n_splits[iter_id] == 2) + << "The tiling structure is not SSSRRSRS"; + if (n_splits[iter_id] == 4) { + ++num_spatial_axes; + last_spatial_iter_id = iter_id; + } + } + + bool all_below_max_extents; + std::vector num_threads_factor_scheme; + do { + all_below_max_extents = true; + + num_threads_factor_scheme = + SamplePerfectTile(rand_state, num_spatial_axes, num_threads_per_block); + for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + if (num_threads_factor_scheme[spatial_iter_id] > max_extents[iter_id]) { + all_below_max_extents = false; + } + ++spatial_iter_id; + } + } // for (iter_id ∈ [0, split_steps_info.size())) + } while (!all_below_max_extents); + + // do the looping again and assign the factors + for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + ret_split_factors[iter_id][1] = num_threads_factor_scheme[spatial_iter_id]; + ++spatial_iter_id; + } + } + + // factor[0] (vthread) + int reg_usage = num_threads_per_block, shmem_usage = 0; + + auto sample_factors = [&](std::function continue_predicate, + std::function max_extent, + std::function factor_to_assign) { + std::vector iter_max_extents; + std::vector factors_to_assign; + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + size_t iter_max_extent = max_extent(iter_id), factor_to_assign; + + std::uniform_int_distribution<> dist(1, iter_max_extent); + factor_to_assign = SampleInt(rand_state, 1, iter_max_extent); + + if (n_splits[iter_id] == 4) { + reg_usage *= factor_to_assign; + } else { + shmem_usage *= factor_to_assign; + } + iter_max_extents.push_back(iter_max_extent); + factors_to_assign.push_back(factor_to_assign); + } + // shuffle the factors + std::vector factors_to_assign_bak = factors_to_assign; + SampleShuffle(rand_state, factors_to_assign.begin(), factors_to_assign.end()); + // make sure that the shuffle is valid + bool valid_shuffle = true; + std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), + factors_to_assign_it = factors_to_assign.begin(); + + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + int iter_max_extent = *iter_max_extents_it; + if (*factors_to_assign_it > iter_max_extent) { + valid_shuffle = false; + } + ++iter_max_extents_it; + ++factors_to_assign_it; + } + if (!valid_shuffle) { + factors_to_assign = std::move(factors_to_assign_bak); + } + // do the actual assignment + 
factors_to_assign_it = factors_to_assign.begin(); + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (continue_predicate(iter_id)) { + continue; + } + factor_to_assign(iter_id) = *factors_to_assign_it; + ++factors_to_assign_it; + } + }; + + sample_factors( + [&](const size_t iter_id) -> bool { + return (n_splits[iter_id] != 4) || (iter_id != last_spatial_iter_id); + }, + [&](const size_t iter_id) -> int { + size_t max_vthread_extent = std::min( + constraints.max_vthread, max_extents[iter_id] / ret_split_factors[iter_id][1]); + max_vthread_extent = + std::min(constraints.max_vthread, constraints.max_local_memory_per_block / reg_usage); + return max_vthread_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); + + // factor[3] (innermost) + sample_factors( + [&](const size_t iter_id) -> bool { + return (n_splits[iter_id] != 4) || (iter_id == last_spatial_iter_id); + }, + [&](const size_t iter_id) -> int { + int max_innermost_extent = + std::min(max_innermost_factor, max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1]); + max_innermost_extent = + std::min(max_innermost_extent, constraints.max_local_memory_per_block / reg_usage); + return max_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][3]; }); + // factor[2] + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 4); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] / + ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], + constraints.max_local_memory_per_block / reg_usage); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][2]; }); + + for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { + if (n_splits[iter_id] == 4) { + shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] * + ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; + } + } + if (shmem_usage > static_cast(constraints.max_shared_memory_per_block / sizeof(float))) { + LOG(FATAL) << "shmem_usage goes out-of-range"; + } + // repeat similar procedure for reduction axes + // rfactor[1] (innermost) + sample_factors( + [&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, + [&](const size_t iter_id) -> int { + int max_innermost_extent = std::min(max_innermost_factor, max_extents[iter_id]); + max_innermost_extent = std::min(max_innermost_extent, + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); + return max_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][1]; }); + // rfactor[0] + sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, + [&](const size_t iter_id) -> size_t { + size_t max_2nd_innermost_extent = + std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], + static_cast(constraints.max_shared_memory_per_block / + sizeof(float) / shmem_usage)); + return max_2nd_innermost_extent; + }, + [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); + } // if (IsCudaTarget(target)) + return ret_split_factors; +} + +std::vector SamplePerfectTile(tir::ScheduleState self, TRandState* rand_state, const tir::StmtSRef& loop_sref, int n, int max_innermost_factor, Optional>* decision) { @@ -52,8 +498,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, 
Sampler::TRandSt result[0] = len; } else { // Case 3. Use fresh new sampling result - std::vector sampled = - Sampler(rand_state).SamplePerfectTile(n, extent, max_innermost_factor); + std::vector sampled = SamplePerfectTile(rand_state, n, extent, max_innermost_factor); result = std::vector(sampled.begin(), sampled.end()); ICHECK_LE(sampled.back(), max_innermost_factor); } @@ -61,7 +506,7 @@ std::vector SamplePerfectTile(tir::ScheduleState self, Sampler::TRandSt return result; } -int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandState* rand_state, +int64_t SampleCategorical(tir::ScheduleState self, TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { int i = -1; @@ -72,14 +517,14 @@ int64_t SampleCategorical(tir::ScheduleState self, Sampler::TRandState* rand_sta CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - i = Sampler(rand_state).MakeMultinomial(AsVector(probs))(); + i = MakeMultinomial(rand_state, AsVector(probs))(); ICHECK(0 <= i && i < n); } *decision = Integer(i); return candidates[i]; } -tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler::TRandState* rand_state, +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, TRandState* rand_state, const tir::StmtSRef& block_sref, Optional* decision) { // Find all possible compute-at locations Array loop_srefs = tir::CollectComputeLocation(self, block_sref); @@ -112,7 +557,7 @@ tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, Sampler::TRandState } } else { // Sample possible combinations - i = Sampler(rand_state).SampleInt(-2, choices.size()); + i = SampleInt(rand_state, -2, choices.size()); if (i >= 0) { i = choices[i]; } diff --git a/src/tir/schedule/sampler.cc b/src/tir/schedule/sampler.cc deleted file mode 100644 index 2ea9f23a89..0000000000 --- a/src/tir/schedule/sampler.cc +++ /dev/null @@ -1,558 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "./sampler.h" - -#include -#include - -#include - -namespace tvm { -namespace tir { - -struct PrimeTable { - /*! \brief The table contains prime numbers in [2, kMaxPrime) */ - static constexpr const int kMaxPrime = 65536; - /*! \brief The exact number of prime numbers in the table */ - static constexpr const int kNumPrimes = 6542; - /*! - * \brief For each number in [2, kMaxPrime), the index of its min factor. - * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. - */ - int min_factor_idx[kMaxPrime]; - /*! \brief The prime numbers in [2, kMaxPrime) */ - std::vector primes; - /*! - * \brief The power of each prime number. - * pow_table[i, j] stores the result of pow(prime[i], j + 1) - */ - std::vector> pow_tab; - - /*! 
\brief Get a global instance of the prime table */ - static const PrimeTable* Global() { - static const PrimeTable table; - return &table; - } - - /*! \brief Constructor, pre-computes all info in the prime table */ - PrimeTable() { - constexpr const int64_t int_max = std::numeric_limits::max(); - // Euler's sieve: prime number in linear time - for (int i = 0; i < kMaxPrime; ++i) { - min_factor_idx[i] = -1; - } - primes.reserve(kNumPrimes); - for (int x = 2; x < kMaxPrime; ++x) { - if (min_factor_idx[x] == -1) { - min_factor_idx[x] = primes.size(); - primes.push_back(x); - } - for (size_t i = 0; i < primes.size(); ++i) { - int factor = primes[i]; - int y = x * factor; - if (y >= kMaxPrime) { - break; - } - min_factor_idx[y] = i; - if (x % factor == 0) { - break; - } - } - } - ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); - // Calculate the power table for each prime number - pow_tab.reserve(primes.size()); - for (int prime : primes) { - std::vector tab; - tab.reserve(32); - for (int64_t pow = prime; pow <= int_max; pow *= prime) { - tab.push_back(pow); - } - tab.shrink_to_fit(); - pow_tab.emplace_back(std::move(tab)); - } - } - /*! - * \brief Factorize a number n, and return in a cryptic format - * \param n The number to be factorized - * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] - * For each pair (i, j), we define - * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) - * (primes[i], j) if i != -1 - * Then the factorization is - * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) - */ - std::vector> Factorize(int n) const { - std::vector> result; - result.reserve(16); - int i = 0, n_primes = primes.size(); - // Phase 1: n >= kMaxPrime - for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - if (j != 0) { - result.emplace_back(i, j); - } - } - // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number - if (n >= kMaxPrime) { - result.emplace_back(-1, n); - return result; - } - // Phase 2: n < kMaxPrime - for (int j; n > 1;) { - int i = min_factor_idx[n]; - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - result.emplace_back(i, j); - } - return result; - } -}; - -Sampler::TRandState Sampler::ForkSeed() { - // In order for reproducibility, we computer the new seed using sampler's RNG's random state and a - // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. - Sampler::TRandState ret = (this->rand_() * 32767) % 1999999973; - return ret; -} - -// We don't need to check the seed here because it's checked in LCE's seed function. 
-void Sampler::Seed(Sampler::TRandState seed) { this->rand_.Seed(seed); } - -int Sampler::SampleInt(int min_inclusive, int max_exclusive) { - if (min_inclusive + 1 == max_exclusive) { - return min_inclusive; - } - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - return dist(rand_); -} - -std::vector Sampler::SampleInts(int n, int min_inclusive, int max_exclusive) { - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; -} - -std::vector Sampler::SampleUniform(int n, double min, double max) { - std::uniform_real_distribution dist(min, max); - std::vector result; - result.reserve(n); - for (int i = 0; i < n; ++i) { - result.push_back(dist(rand_)); - } - return result; -} - -bool Sampler::SampleBernoulli(double p) { - std::bernoulli_distribution dist(p); - return dist(rand_); -} - -std::function Sampler::MakeMultinomial(const std::vector& weights) { - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - for (double w : weights) { - sums.push_back(sum += w); - } - std::uniform_real_distribution dist(0.0, sum); - auto sampler = [this, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { - double p = dist(rand_); - int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? (n - 1) : idx; - }; - return sampler; -} - -std::vector Sampler::SampleWithoutReplacement(int n, int k) { - if (k == 1) { - return {SampleInt(0, n)}; - } - if (k == 2) { - int result0 = SampleInt(0, n); - int result1 = SampleInt(0, n - 1); - if (result1 >= result0) { - result1 += 1; - } - return {result0, result1}; - } - std::vector order(n); - for (int i = 0; i < n; ++i) { - order[i] = i; - } - for (int i = 0; i < k; ++i) { - int j = SampleInt(i, n); - if (i != j) { - std::swap(order[i], order[j]); - } - } - return {order.begin(), order.begin() + k}; -} - -std::vector Sampler::SampleTileFactor(int n, int extent, const std::vector& candidates) { - constexpr int kMaxTrials = 100; - std::uniform_int_distribution<> dist(0, static_cast(candidates.size()) - 1); - std::vector sample(n, -1); - for (int trial = 0; trial < kMaxTrials; ++trial) { - int64_t product = 1; - for (int i = 1; i < n; ++i) { - int value = candidates[dist(rand_)]; - product *= value; - if (product > extent) { - break; - } - sample[i] = value; - } - if (product <= extent) { - sample[0] = (extent + product - 1) / product; - return sample; - } - } - sample[0] = extent; - for (int i = 1; i < n; ++i) { - sample[i] = 1; - } - return sample; -} - -std::vector Sampler::SamplePerfectTile(int n_splits, int extent) { - CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; - CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; - // Handle special case that we can potentially accelerate - if (n_splits == 1) { - return {extent}; - } - if (extent == 1) { - return std::vector(n_splits, 1); - } - // Enumerate each pair (i, j), we define - // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) - // (primes[i], j) if i != -1 - // Then the factorization is - // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... 
(a_l ^ p_l) - const PrimeTable* prime_tab = PrimeTable::Global(); - std::vector> factorized = prime_tab->Factorize(extent); - if (n_splits == 2) { - // n_splits = 2, this can be taken special care of, - // because general reservoir sampling can be avoided to accelerate the sampling - int result0 = 1; - int result1 = 1; - for (const std::pair& ij : factorized) { - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - (SampleInt(0, 2) ? result1 : result0) *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - int x1 = SampleInt(0, p + 1); - int x2 = p - x1; - if (x1 != 0) { - result0 *= pow[x1]; - } - if (x2 != 0) { - result1 *= pow[x2]; - } - } - return {result0, result1}; - } - // Data range: - // 2 <= extent <= 2^31 - 1 - // 3 <= n_splits <= max tiling splits - // 1 <= p <= 31 - std::vector result(n_splits, 1); - for (const std::pair& ij : factorized) { - // Handle special cases to accelerate sampling - // Case 1: (a, p) = (j, 1), where j is a prime number - if (ij.first == -1) { - result[SampleInt(0, n_splits)] *= ij.second; - continue; - } - // Case 2: (a = primes[i], p = 1) - int p = ij.second; - if (p == 1) { - result[SampleInt(0, n_splits)] *= prime_tab->primes[ij.first]; - continue; - } - // The general case. We have to sample uniformly from the solution of: - // x_1 + x_2 + ... + x_{n_splits} = p - // where x_i >= 0 - // Data range: - // 2 <= p <= 31 - // 3 <= n_splits <= max tiling splits - std::vector sampled = SampleWithoutReplacement(p + n_splits - 1, n_splits - 1); - std::sort(sampled.begin(), sampled.end()); - sampled.push_back(p + n_splits - 1); - const int* pow = prime_tab->pow_tab[ij.first].data() - 1; - for (int i = 0, last = -1; i < n_splits; ++i) { - int x = sampled[i] - last - 1; - last = sampled[i]; - if (x != 0) { - result[i] *= pow[x]; - } - } - } - return result; -} - -std::vector Sampler::SamplePerfectTile(int n_splits, int extent, int max_innermost_factor) { - if (max_innermost_factor == -1) { - return this->SamplePerfectTile(n_splits, extent); - } - CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); - } - } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. 
On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add - // more heuristics in the future - int innermost = innermost_candidates[SampleInt(0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(n_splits - 1, extent / innermost); - result.push_back(innermost); - return result; -} - -static inline int ExtractInt(const Target& target, const char* name) { - if (Optional v = target->GetAttr(name)) { - return v.value()->value; - } - LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; - throw; -} - -static inline bool IsCudaTarget(const Target& target) { - if (Optional v = target->GetAttr("kind")) { - return v.value() == "cuda"; - } - return false; -} - -std::vector> Sampler::SampleShapeGenericTiles(const std::vector& n_splits, - const std::vector& max_extents, - const Target& target, - int max_innermost_factor) { - std::vector> ret_split_factors; - - if (IsCudaTarget(target)) { - // The following factorization scheme is built under the assumption that: (1) The target is - // CUDA, and (2) The tiling structure is SSSRRSRS. - - // extract all the hardware parameters - const struct HardwareConstraints { - int max_shared_memory_per_block; - int max_local_memory_per_block; - int max_threads_per_block; - int max_innermost_factor; - int max_vthread; - } constraints = {ExtractInt(target, "shared_memory_per_block"), - ExtractInt(target, "registers_per_block"), - ExtractInt(target, "max_threads_per_block"), max_innermost_factor, 8}; - - for (const int n_split : n_splits) { - ret_split_factors.push_back(std::vector(n_split, 1)); - } - - // sample the number of threads per block - const int warp_size = ExtractInt(target, "warp_size"); - int num_threads_per_block = - SampleInt(1, constraints.max_threads_per_block / warp_size) * warp_size; - // find all the possible factors of the number of threads per block - int num_spatial_axes = 0; - size_t last_spatial_iter_id = -1; - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - CHECK(n_splits[iter_id] == 4 || n_splits[iter_id] == 2) - << "The tiling structure is not SSSRRSRS"; - if (n_splits[iter_id] == 4) { - ++num_spatial_axes; - last_spatial_iter_id = iter_id; - } - } - - bool all_below_max_extents; - std::vector num_threads_factor_scheme; - do { - all_below_max_extents = true; - - num_threads_factor_scheme = SamplePerfectTile(num_spatial_axes, num_threads_per_block); - for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - if (num_threads_factor_scheme[spatial_iter_id] > max_extents[iter_id]) { - all_below_max_extents = false; - } - ++spatial_iter_id; - } - } // for (iter_id ∈ [0, split_steps_info.size())) - } while (!all_below_max_extents); - - // do the looping again and assign the factors - for (size_t iter_id = 0, spatial_iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - ret_split_factors[iter_id][1] = num_threads_factor_scheme[spatial_iter_id]; - ++spatial_iter_id; - } - } - - // factor[0] (vthread) - int reg_usage = num_threads_per_block, shmem_usage = 0; - - auto sample_factors = [&](std::function continue_predicate, - std::function max_extent, - std::function factor_to_assign) { - std::vector iter_max_extents; - std::vector factors_to_assign; - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - 
size_t iter_max_extent = max_extent(iter_id), factor_to_assign; - - std::uniform_int_distribution<> dist(1, iter_max_extent); - factor_to_assign = SampleInt(1, iter_max_extent); - - if (n_splits[iter_id] == 4) { - reg_usage *= factor_to_assign; - } else { - shmem_usage *= factor_to_assign; - } - iter_max_extents.push_back(iter_max_extent); - factors_to_assign.push_back(factor_to_assign); - } - // shuffle the factors - std::vector factors_to_assign_bak = factors_to_assign; - Shuffle(factors_to_assign.begin(), factors_to_assign.end()); - // make sure that the shuffle is valid - bool valid_shuffle = true; - std::vector::iterator iter_max_extents_it = iter_max_extents.begin(), - factors_to_assign_it = factors_to_assign.begin(); - - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - int iter_max_extent = *iter_max_extents_it; - if (*factors_to_assign_it > iter_max_extent) { - valid_shuffle = false; - } - ++iter_max_extents_it; - ++factors_to_assign_it; - } - if (!valid_shuffle) { - factors_to_assign = std::move(factors_to_assign_bak); - } - // do the actual assignment - factors_to_assign_it = factors_to_assign.begin(); - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (continue_predicate(iter_id)) { - continue; - } - factor_to_assign(iter_id) = *factors_to_assign_it; - ++factors_to_assign_it; - } - }; - - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || (iter_id != last_spatial_iter_id); - }, - [&](const size_t iter_id) -> int { - size_t max_vthread_extent = std::min( - constraints.max_vthread, max_extents[iter_id] / ret_split_factors[iter_id][1]); - max_vthread_extent = - std::min(constraints.max_vthread, constraints.max_local_memory_per_block / reg_usage); - return max_vthread_extent; - }, - [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); - - // factor[3] (innermost) - sample_factors( - [&](const size_t iter_id) -> bool { - return (n_splits[iter_id] != 4) || (iter_id == last_spatial_iter_id); - }, - [&](const size_t iter_id) -> int { - int max_innermost_extent = - std::min(max_innermost_factor, max_extents[iter_id] / ret_split_factors[iter_id][0] / - ret_split_factors[iter_id][1]); - max_innermost_extent = - std::min(max_innermost_extent, constraints.max_local_memory_per_block / reg_usage); - return max_innermost_extent; - }, - [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][3]; }); - // factor[2] - sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 4); }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][0] / - ret_split_factors[iter_id][1] / ret_split_factors[iter_id][3], - constraints.max_local_memory_per_block / reg_usage); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][2]; }); - - for (size_t iter_id = 0; iter_id < n_splits.size(); ++iter_id) { - if (n_splits[iter_id] == 4) { - shmem_usage += ret_split_factors[iter_id][0] * ret_split_factors[iter_id][1] * - ret_split_factors[iter_id][2] * ret_split_factors[iter_id][3]; - } - } - if (shmem_usage > static_cast(constraints.max_shared_memory_per_block / sizeof(float))) { - LOG(FATAL) << "shmem_usage goes out-of-range"; - } - // repeat similar procedure for reduction axes - // rfactor[1] (innermost) - sample_factors( - [&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 
2); }, - [&](const size_t iter_id) -> int { - int max_innermost_extent = std::min(max_innermost_factor, max_extents[iter_id]); - max_innermost_extent = std::min(max_innermost_extent, - static_cast(constraints.max_shared_memory_per_block / - sizeof(float) / shmem_usage)); - return max_innermost_extent; - }, - [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][1]; }); - // rfactor[0] - sample_factors([&](const size_t iter_id) -> bool { return (n_splits[iter_id] != 2); }, - [&](const size_t iter_id) -> size_t { - size_t max_2nd_innermost_extent = - std::min(max_extents[iter_id] / ret_split_factors[iter_id][1], - static_cast(constraints.max_shared_memory_per_block / - sizeof(float) / shmem_usage)); - return max_2nd_innermost_extent; - }, - [&](const size_t iter_id) -> int& { return ret_split_factors[iter_id][0]; }); - } // if (IsCudaTarget(target)) - return ret_split_factors; -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/schedule/sampler.h b/src/tir/schedule/sampler.h deleted file mode 100644 index c3dd5704fa..0000000000 --- a/src/tir/schedule/sampler.h +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_TIR_SCHEDULE_SAMPLER_H_ -#define TVM_TIR_SCHEDULE_SAMPLER_H_ - -#include - -#include -#include -#include -#include -namespace tvm { - -class Target; - -namespace tir { - -/*! - * \brief Sampler based on random number generator for sampling in meta schedule. - * \note Typical usage is like Sampler(&random_state).SamplingFunc(...). - */ -class Sampler { - public: - /*! Random state type for random number generator. */ - using TRandState = support::LinearCongruentialEngine::TRandState; - /*! - * \brief Return a random state value that can be used as seed for new samplers. - * \return The random state value to be used as seed for new samplers. - */ - TRandState ForkSeed(); - /*! - * \brief Re-seed the random number generator - * \param seed The random state value given used to re-seed the RNG. - */ - void Seed(TRandState seed); - /*! - * \brief Sample an integer in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The integer sampled - */ - int SampleInt(int min_inclusive, int max_exclusive); - /*! - * \brief Sample n integers in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The list of integers sampled - */ - std::vector SampleInts(int n, int min_inclusive, int max_exclusive); - /*! - * \brief Random shuffle from the begin iterator to the end. 
- * \param begin_it The begin iterator - * \param end_it The end iterator - */ - template - void Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it); - /*! - * \brief Sample n tiling factors of the specific extent - * \param n The number of parts the loop is split - * \param extent Length of the loop - * \param candidates The possible tiling factors - * \return A list of length n, the tiling factors sampled - */ - std::vector SampleTileFactor(int n, int extent, const std::vector& candidates); - /*! - * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent); - /*! - * \brief Sample perfect tiling factor of the specific extent - * \param n_splits The number of parts the loop is split - * \param extent Length of the loop - * \param max_innermost_factor A small number indicating the max length of the innermost loop - * \return A list of length n_splits, the tiling factors sampled, the product of which strictly - * equals to extent - */ - std::vector SamplePerfectTile(int n_splits, int extent, int max_innermost_factor); - /*! - * \brief Sample shape-generic tiling factors that are determined by the hardware constraints. - * \param n_splits The number of parts the loops are split - * \param max_extents Maximum length of the loops - * \param is_spatial Whether each loop is a spatial axis or not - * \param target Hardware target - * \param max_innermost_factor A small number indicating the max length of the innermost loop - * \return A list of list of length n_splits, the tiling factors sampled, all satisfying the - * maximum extents and the hardware constraints - */ - std::vector> SampleShapeGenericTiles(const std::vector& n_splits, - const std::vector& max_extents, - const Target& target, - int max_innermost_factor); - /*! - * \brief Sample n floats uniformly in [min, max) - * \param min The left boundary - * \param max The right boundary - * \return The list of floats sampled - */ - std::vector SampleUniform(int n, double min, double max); - /*! - * \brief Sample from a Bernoulli distribution - * \param p Parameter in the Bernoulli distribution - * \return return true with probability p, and false with probability (1 - p) - */ - bool SampleBernoulli(double p); - /*! - * \brief Create a multinomial sampler based on the specific weights - * \param weights The weights, event probabilities - * \return The multinomial sampler - */ - std::function MakeMultinomial(const std::vector& weights); - /*! - * \brief Classic sampling without replacement - * \param n The population size - * \param k The number of samples to be drawn from the population - * \return A list of indices, samples drawn, unsorted and index starting from 0 - */ - std::vector SampleWithoutReplacement(int n, int k); - /*! \brief The default constructor function for Sampler */ - Sampler() = default; - /*! - * \brief Constructor. Construct a sampler with a given random state pointer for its RNG. - * \param random_state The given pointer to random state used to construct the RNG. - * \note The random state is neither initialized not modified by this constructor. - */ - explicit Sampler(TRandState* random_state) : rand_(random_state) {} - - private: - /*! \brief The random number generator for sampling. 
*/ - support::LinearCongruentialEngine rand_; -}; - -template -void Sampler::Shuffle(RandomAccessIterator begin_it, RandomAccessIterator end_it) { - std::shuffle(begin_it, end_it, rand_); -} - -} // namespace tir -} // namespace tvm - -#endif // TVM_TIR_SCHEDULE_SAMPLER_H_ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index ce66e75919..57e3cd63b7 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,26 +21,26 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, Sampler::TRandState seed, int debug_mode, +Schedule Schedule::Traced(IRModule mod, tir::TRandState seed, int debug_mode, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); n->error_render_level_ = error_render_level; if (seed == -1) seed = std::random_device()(); - Sampler(&n->rand_state_).Seed(seed); + tir::RandEngine(&n->rand_state_).Seed(seed); n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); return Schedule(std::move(n)); } -Schedule TracedScheduleNode::Copy(Sampler::TRandState new_seed) const { +Schedule TracedScheduleNode::Copy(tir::TRandState new_seed) const { ObjectPtr n = make_object(); ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->error_render_level_ = this->error_render_level_; n->analyzer_ = std::make_unique(); if (new_seed == -1) new_seed = std::random_device()(); - Sampler(&n->rand_state_).Seed(new_seed); + tir::RandEngine(&n->rand_state_).Seed(new_seed); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); return Schedule(std::move(n)); } @@ -453,7 +453,7 @@ void TracedScheduleNode::InlineArgument(int i, const String& func_name) { TVM_REGISTER_NODE_TYPE(TracedScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, Sampler::TRandState seed, int debug_mode, + .set_body_typed([](IRModule mod, tir::TRandState seed, int debug_mode, int error_render_level) -> Schedule { return Schedule::Traced(mod, seed, debug_mode, static_cast(error_render_level)); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index beacce0ddc..a0e7c714c6 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,7 +47,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: Optional trace() const final { return trace_; } - Schedule Copy(Sampler::TRandState new_seed = -1) const final; + Schedule Copy(tir::TRandState new_seed = -1) const final; public: /******** Schedule: Sampling ********/ diff --git a/tests/cpp/meta_schedule_test.cc b/tests/cpp/meta_schedule_test.cc index 4b68da7003..a01cbfef14 100644 --- a/tests/cpp/meta_schedule_test.cc +++ b/tests/cpp/meta_schedule_test.cc @@ -20,13 +20,13 @@ #include #include -#include "../../../src/tir/schedule/sampler.h" +#include "../../../src/tir/schedule/primitive.h" -TEST(Simplify, Sampler) { +TEST(Simplify, Sampling) { int64_t current = 100; for (int i = 0; i < 10; i++) { - tvm::tir::Sampler(¤t).SampleInt(0, 100); - tvm::tir::Sampler(¤t).SampleUniform(3, -1, 0); + tvm::tir::SampleInt(¤t, 0, 100); + tvm::tir::SampleUniform(¤t, 3, -1, 0); } } From 6f3aed92eb8c3e5e5ce36b90a74702be1bf688a7 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 10 Aug 2021 18:36:56 -0700 Subject: [PATCH 14/23] Fix rebase error. 
--- src/meta_schedule/space/post_order_apply.cc | 2 +- src/meta_schedule/space/schedule_fn.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 2ea4af08e9..35fc71de21 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -99,7 +99,7 @@ PostOrderApply::PostOrderApply(Array stages, Array postpro bool PostOrderApplyNode::Postprocess(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) { - sch->EnterPostProc(); + sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { return false; diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index de179a9804..9bce1872f4 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -97,7 +97,7 @@ ScheduleFn::ScheduleFn(PackedFunc sch_fn, Array postprocs) { bool ScheduleFnNode::Postprocess(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) { - sch->EnterPostProc(); + sch->EnterPostproc(); for (const Postproc& postproc : postprocs) { if (!postproc->Apply(task, sch, rand_state)) { return false; From caa109e3adef0da4be7d48a1dcd23fdc510b1012 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 16:48:15 -0700 Subject: [PATCH 15/23] Fix rpc tests. --- python/tvm/meta_schedule/utils.py | 12 ++--- .../cost_model/rand_cost_model.cc | 2 +- .../meta_schedule/test_integration_cuda.py | 2 +- ...test_meta_schedule_bsr_sparse_dense_cpu.py | 48 +++++++++---------- ...st_meta_schedule_layout_rewrite_network.py | 10 ++-- .../meta_schedule/test_resnet_end_to_end.py | 6 +-- .../test_resnet_end_to_end_cuda.py | 4 +- 7 files changed, 41 insertions(+), 43 deletions(-) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 294098cf50..83a46e6cc0 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -46,7 +46,7 @@ def make_error_msg() -> str: - """ Get the error message from traceback. 
""" + """Get the error message from traceback.""" error_msg = str(traceback.format_exc()) if len(error_msg) > MAX_ERROR_MSG_LEN: error_msg = ( @@ -408,7 +408,7 @@ def local_builder_worker( timeout: int, verbose: int, ) -> BuildResult.TYPE: - """ Local worker for ProgramBuilder """ + """Local worker for ProgramBuilder""" # deal with build_func build_func = { "tar": build_func_tar.tar, # export to tar @@ -603,7 +603,7 @@ def rpc_runner_worker( f_create_args: Callable[[Device], List[NDArray]], verbose: int, ) -> MeasureResult.TYPE: - """ RPC worker for ProgramRunner """ + """RPC worker for ProgramRunner""" measure_input = measure_inputs[index] build_result = build_results[index] @@ -653,13 +653,13 @@ def timed_func(): else: rpc_eval_repeat = 1 if f_create_args is not None: - args_set = [f_create_args(ctx) for _ in range(rpc_eval_repeat)] + args_set = [f_create_args(dev) for _ in range(rpc_eval_repeat)] else: args_set = [ - realize_arguments(remote, ctx, measure_input.sch.mod["main"]) + realize_arguments(remote, dev, measure_input.sch.mod["main"]) for _ in range(rpc_eval_repeat) ] - ctx.sync() + dev.sync() costs = sum([time_f(*args).results for args in args_set], ()) # clean up remote files remote.remove(build_result.filename) diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 0afc722f7d..2389b7f1c9 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -27,7 +27,7 @@ namespace meta_schedule { /*! \brief The cost model returning random value for all predictions */ class RandCostModelNode : public CostModelNode { public: - /*! \brief A random state for sampler to generate random numbers */ + /*! \brief A random state for sampling functions to generate random numbers */ tir::TRandState rand_state; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/tests/python/meta_schedule/test_integration_cuda.py b/tests/python/meta_schedule/test_integration_cuda.py index 67c5056f4e..361a7a52cd 100644 --- a/tests/python/meta_schedule/test_integration_cuda.py +++ b/tests/python/meta_schedule/test_integration_cuda.py @@ -30,7 +30,7 @@ logging.getLogger("meta_schedule").setLevel(logging.DEBUG) RPC_KEY = "rtx-3070" -TARGET = tvm.target.Target("nvidia/geforce-rtx-3070") +TARGET = tvm.target.Target("nvidia/geforce-rtx-2080-ti") TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( stages=[ diff --git a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py index 8866d0d6ad..55d644d4c8 100644 --- a/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py +++ b/tests/python/meta_schedule/test_meta_schedule_bsr_sparse_dense_cpu.py @@ -143,7 +143,7 @@ def test_sparse_dense(): print("M =", M, "N =", N, "K =", K, "BS_R =", BS_R, "BS_C = ", BS_C) def check_device(device): - ctx = tvm.context(device, 0) + ctx = tvm.device(device, 0) if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) return @@ -154,12 +154,12 @@ def check_device(device): Y = fcompute(X, W_data, W_indices, W_indptr) s = fschedule([Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device=ctx) func( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - 
tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) @@ -168,10 +168,10 @@ def check_device(device): "sparse dense te schedule: %f ms" % ( evaluator( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ).mean * 1e3 @@ -187,23 +187,23 @@ def check_device(device): func = func.specialize(N_blocks, N // BS_R).remove_const_param(N_blocks) def f_create_args(ctx): - X = tvm.nd.array(X_np, ctx=ctx) - W_data = tvm.nd.array(W_sp_np.data, ctx=ctx) - W_indices = tvm.nd.array(W_sp_np.indices, ctx=ctx) - W_indptr = tvm.nd.array(W_sp_np.indptr, ctx=ctx) - Y = tvm.nd.array(Y_np, ctx=ctx) + X = tvm.nd.array(X_np, device=ctx) + W_data = tvm.nd.array(W_sp_np.data, device=ctx) + W_indices = tvm.nd.array(W_sp_np.indices, device=ctx) + W_indptr = tvm.nd.array(W_sp_np.indptr, device=ctx) + Y = tvm.nd.array(Y_np, device=ctx) return [X, W_data, W_indices, W_indptr, Y] sch = meta_schedule_sparse_dense_llvm(func, f_create_args) func = sch.mod func = tvm.build(func) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device=ctx) func( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ) tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) @@ -212,10 +212,10 @@ def f_create_args(ctx): "sparse dense auto tir schedule: %f ms" % ( evaluator( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), + tvm.nd.array(X_np, device=ctx), + tvm.nd.array(W_sp_np.data, device=ctx), + tvm.nd.array(W_sp_np.indices, device=ctx), + tvm.nd.array(W_sp_np.indptr, device=ctx), Y_tvm, ).mean * 1e3 diff --git a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py index 78061c722d..eee39ec024 100644 --- a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py +++ b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py @@ -99,9 +99,9 @@ def get_relay_batchmm(batch=4, m=128, n=128, k=128): return mod, data, weight -RPC_KEY = "raspi4b-aarch64" -TARGET = tvm.target.Target("raspberry-pi/4b-64") -TARGET_HOST = tvm.target.Target("raspberry-pi/4b-64") +RPC_KEY = "test" +TARGET = tvm.target.Target("llvm") +TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( stages=[ ms.rule.inline_pure_spatial(strict_mode=True), @@ -200,9 +200,7 @@ def run_module(lib, use_arm): lib.export_library(tmp.relpath(filename)) # Upload module to device print("Upload...") - remote = auto_scheduler.utils.request_remote( - RPC_KEY, "172.16.2.241", 4445, timeout=10000 - ) + remote = auto_scheduler.utils.request_remote(RPC_KEY, "localhost", 4728, timeout=10000) 
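+            # NOTE: assumes an RPC tracker/server reachable at localhost:4728 under the
+            # "test" key; adjust RPC_KEY and the address when targeting a real remote device.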
remote.upload(tmp.relpath(filename)) rlib = remote.load_module(filename) diff --git a/tests/python/meta_schedule/test_resnet_end_to_end.py b/tests/python/meta_schedule/test_resnet_end_to_end.py index c0ada49b01..0b1b97b19e 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end.py @@ -96,13 +96,13 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "raspi4b-aarch64" +RPC_KEY = "local" network = "resnet-50" batch_size = 1 layout = "NHWC" -target = tvm.target.Target("raspberry-pi/4b-64") +target = tvm.target.Target("llvm") dtype = "float32" -TARGET_HOST = tvm.target.Target("raspberry-pi/4b-64") +TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( stages=[ ms.rule.inline_pure_spatial(strict_mode=True), diff --git a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py index 71a0f649e4..d9133036fe 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py @@ -96,11 +96,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "rtx-3080" +RPC_KEY = "rtx-3070" network = "resnet-50" batch_size = 1 layout = "NHWC" -target = tvm.target.Target("nvidia/geforce-rtx-3080") +target = tvm.target.Target("nvidia/geforce-rtx-2080-ti") dtype = "float32" TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( From 0fffeece0ca1532ee7bbee9722c991d49077eb81 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 17:06:03 -0700 Subject: [PATCH 16/23] Remove sampler comments. --- include/tvm/support/random_engine.h | 2 +- src/meta_schedule/autotune.h | 2 -- .../cost_model/rand_cost_model.cc | 4 +--- src/meta_schedule/search.cc | 2 +- src/meta_schedule/search.h | 2 +- src/meta_schedule/space/post_order_apply.cc | 2 +- src/meta_schedule/space/postproc.h | 2 +- src/meta_schedule/space/schedule_fn.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 22 +++++++++---------- src/meta_schedule/strategy/mutator.h | 2 +- src/tir/schedule/primitive.h | 2 +- src/tir/schedule/primitive/sampling.cc | 2 +- 12 files changed, 21 insertions(+), 25 deletions(-) diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 0889a383d6..9713b69d42 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -44,7 +44,7 @@ namespace support { class LinearCongruentialEngine { public: /*! - * \brief The result type is defined as int64_t here for meta_schedule sampler usage. + * \brief The result type is defined as int64_t here to avoid overflow. * \note The type name is not in Google style because it is used in STL's distribution inferface. 
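+ * \note Illustrative usage, mirroring how the sampling primitives consume the engine:
+ *         LinearCongruentialEngine::TRandState state = 1;
+ *         LinearCongruentialEngine rng(&state);
+ *         std::uniform_int_distribution<> dist(0, 9);
+ *         int sampled = dist(rng);  // any STL distribution can drive the engine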
*/ using result_type = uint64_t; diff --git a/src/meta_schedule/autotune.h b/src/meta_schedule/autotune.h index e61027a924..92ae034a53 100644 --- a/src/meta_schedule/autotune.h +++ b/src/meta_schedule/autotune.h @@ -57,7 +57,6 @@ class TuneContextNode : public runtime::Object { v->Visit("postprocs", &postprocs); v->Visit("measure_callbacks", &measure_callbacks); v->Visit("num_threads", &num_threads); - // `sampler` is not visited } void Init(Optional seed = NullOpt); @@ -94,7 +93,6 @@ class TuneContext : public runtime::ObjectRef { n->postprocs = postprocs; n->measure_callbacks = measure_callbacks; n->num_threads = num_threads; - // `n->sampler` is not initialized data_ = std::move(n); (*this)->Init(seed); } diff --git a/src/meta_schedule/cost_model/rand_cost_model.cc b/src/meta_schedule/cost_model/rand_cost_model.cc index 2389b7f1c9..dba3c1b93d 100644 --- a/src/meta_schedule/cost_model/rand_cost_model.cc +++ b/src/meta_schedule/cost_model/rand_cost_model.cc @@ -30,9 +30,7 @@ class RandCostModelNode : public CostModelNode { /*! \brief A random state for sampling functions to generate random numbers */ tir::TRandState rand_state; - void VisitAttrs(tvm::AttrVisitor* v) { - // sampler is not visited - } + void VisitAttrs(tvm::AttrVisitor* v) {} /*! * \brief Update the cost model according to new measurement results (training data). diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index a872e2dd75..e13e3bf274 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -102,7 +102,7 @@ struct Internal { * \brief Apply postprocessors onto the schedule * \param space The search space * \param sch The schedule to be postprocessed - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return Whether postprocessing has succeeded * \sa SearchSpaceNode::Postprocess */ diff --git a/src/meta_schedule/search.h b/src/meta_schedule/search.h index 3bc6571d8f..d184a0b48a 100644 --- a/src/meta_schedule/search.h +++ b/src/meta_schedule/search.h @@ -102,7 +102,7 @@ class SearchSpaceNode : public runtime::Object { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling */ virtual bool Postprocess(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) = 0; diff --git a/src/meta_schedule/space/post_order_apply.cc b/src/meta_schedule/space/post_order_apply.cc index 35fc71de21..58a05658e8 100644 --- a/src/meta_schedule/space/post_order_apply.cc +++ b/src/meta_schedule/space/post_order_apply.cc @@ -49,7 +49,7 @@ class PostOrderApplyNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling */ bool Postprocess(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) override; diff --git a/src/meta_schedule/space/postproc.h b/src/meta_schedule/space/postproc.h index d9673fa48c..770559820f 100644 --- a/src/meta_schedule/space/postproc.h +++ b/src/meta_schedule/space/postproc.h @@ -44,7 +44,7 @@ class PostprocNode : public Object { /*! 
* \brief Apply the postprocessor * \param sch The schedule to be processed - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return If the post-processing succeeds */ bool Apply(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state); diff --git a/src/meta_schedule/space/schedule_fn.cc b/src/meta_schedule/space/schedule_fn.cc index 9bce1872f4..c009f1a09c 100644 --- a/src/meta_schedule/space/schedule_fn.cc +++ b/src/meta_schedule/space/schedule_fn.cc @@ -47,7 +47,7 @@ class ScheduleFnNode : public SearchSpaceNode { * \brief Apply postprocessors onto the schedule * \param task The search task * \param sch The schedule to be postprocessed - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling */ bool Postprocess(const SearchTask& task, const Schedule& sch, tir::TRandState* rand_state) override; diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index eafceb6695..53dbdea59b 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -134,7 +134,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \param task The search task * \param space The search space * \param measurer The measurer that builds, runs and profiles sampled programs - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \param verbose Whether or not in verbose mode * \return The best schedule found, NullOpt if no valid schedule is found */ @@ -151,7 +151,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \param support The support to be sampled from * \param task The search task * \param space The search space - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return The generated samples, all of which are not post-processed */ Array SampleInitPopulation(const Array& support, const SearchTask& task, @@ -162,7 +162,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \param inits The initial population * \param task The search task * \param space The search space - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return An array of schedules, the sampling result */ Array EvolveWithCostModel(const Array& inits, const SearchTask& task, @@ -174,7 +174,7 @@ class EvolutionaryNode : public SearchStrategyNode { * \param bests The best populations according to the cost model when picking top states * \param task The search task * \param space The search space - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return A list of schedules, result of epsilon-greedy sampling */ Array PickWithEpsGreedy(const Array& inits, const Array& bests, @@ -200,12 +200,12 @@ class EvolutionaryNode : public SearchStrategyNode { friend class Evolutionary; /*! 
- * \brief Fork a sampler into `n` samplers - * \param n The number of samplers to be forked - * \param rand_state The sampler's random state + * \brief Fork a random state into `n` random states + * \param n The number of random states to be forked + * \param rand_state The random state for sampling * \return A list of random states, the result of forking */ - static std::vector ForkSamplers(int n, tir::TRandState* rand_state) { + static std::vector ForkRandStates(int n, tir::TRandState* rand_state) { std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) { @@ -243,7 +243,7 @@ class EvolutionaryNode : public SearchStrategyNode { /*! * \brief Create a sampler function that picks mutators according to the mass function - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return The sampler created */ static std::function()> MakeMutatorSampler( @@ -477,7 +477,7 @@ Array EvolutionaryNode::SampleInitPopulation(const Array& suppo results.reserve(this->population); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); + std::vector thread_rand_states = ForkRandStates(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); // Pick measured states int num_measured = this->population * this->init_measured_ratio; @@ -550,7 +550,7 @@ Array EvolutionaryNode::EvolveWithCostModel(const Array& inits, SizedHeap heap(this->num_measures_per_iteration); // Threading RNG int num_threads = std::thread::hardware_concurrency(); - std::vector thread_rand_states = ForkSamplers(num_threads, global_rand_state); + std::vector thread_rand_states = ForkRandStates(num_threads, global_rand_state); std::vector thread_workloads = ForkWorkload(num_threads, task->workload); std::vector> thread_trace_samplers(num_threads); std::vector()>> thread_mutator_samplers(num_threads); diff --git a/src/meta_schedule/strategy/mutator.h b/src/meta_schedule/strategy/mutator.h index dae5cd30b3..286462c47c 100644 --- a/src/meta_schedule/strategy/mutator.h +++ b/src/meta_schedule/strategy/mutator.h @@ -44,7 +44,7 @@ class MutatorNode : public Object { * \brief Mutate the schedule by applying the mutation * \param task The search task * \param trace The trace to be mutated - * \param rand_state The sampler's random state + * \param rand_state The random state for sampling * \return The new schedule after mutation, NullOpt if mutation fails */ Optional Apply(const SearchTask& task, const Trace& trace, tir::TRandState* rand_state); diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 561812e81e..99db58c5f8 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -131,7 +131,7 @@ struct PrimeTable { /******** Schedule: Sampling ********/ -/*! \brief Return a seed that can be used to create a new sampler */ +/*! \brief Return a seed that can be used as a new random state. */ TRandState ForkSeed(TRandState* rand_state); /*! 
* \brief Sample an integer in [min_inclusive, max_exclusive) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 7057527ef9..ba601382f6 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -25,7 +25,7 @@ namespace tvm { namespace tir { TRandState ForkSeed(TRandState* rand_state) { - // In order for reproducibility, we computer the new seed using sampler's RNG's random state and a + // In order for reproducibility, we computer the new seed using RNG's random state and a // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. TRandState ret = (RandEngine(rand_state)() * 32767) % 1999999973; return ret; From 015fda3b100e694c594caf813befef8329252db4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 17:15:46 -0700 Subject: [PATCH 17/23] Fix seeding process. --- src/meta_schedule/autotune.cc | 2 +- src/meta_schedule/search.cc | 40 ++++++++++++++-------- src/meta_schedule/space/postproc.cc | 8 +++-- src/meta_schedule/strategy/evolutionary.cc | 24 ++++++++----- src/meta_schedule/strategy/mutator.cc | 8 +++-- 5 files changed, 51 insertions(+), 31 deletions(-) diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index 1e028b15ce..de7b0edd23 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -24,7 +24,7 @@ namespace tvm { namespace meta_schedule { void TuneContextNode::Init(Optional seed) { - if (seed.defined() && seed.value() != -1) { + if (seed.defined() && (seed.value() != -1)) { tir::RandEngine(&this->rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&this->rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index e13e3bf274..e1adc04dd5 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -58,9 +58,11 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, */ TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } if (verbose) { @@ -108,9 +110,11 @@ struct Internal { */ static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return space->Postprocess(task, sch, &rand_state); } @@ -123,9 +127,11 @@ struct Internal { */ static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return space->SampleSchedule(task, 
&rand_state); } @@ -139,9 +145,11 @@ struct Internal { */ static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return space->GetSupport(task, &rand_state); } @@ -157,9 +165,11 @@ struct Internal { static Optional SearchStrategySearch(SearchStrategy strategy, SearchTask task, SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return strategy->Search(task, space, measurer, &rand_state, verbose); } diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index 5c32cf1da2..b957b8ce90 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -1118,9 +1118,11 @@ struct Internal { * \sa PostProcNode::Apply */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return self->Apply(task, sch, &rand_state); } diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index 53dbdea59b..757fdcaa92 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -779,9 +779,11 @@ struct Internal { static Array SampleInitPopulation(Evolutionary self, Array support, SearchTask task, SearchSpace space, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return self->SampleInitPopulation(support, task, space, &rand_state); } @@ -797,9 +799,11 @@ struct Internal { */ static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return self->EvolveWithCostModel(inits, task, space, &rand_state); } @@ -815,9 +819,11 @@ struct Internal { static Array PickWithEpsGreedy(Evolutionary self, Array inits, Array bests, SearchTask task, SearchSpace space, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - 
tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return self->PickWithEpsGreedy(inits, bests, task, space, &rand_state); } diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index c6d39b452e..1adf5df9e3 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -481,9 +481,11 @@ struct Internal { */ static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { - tir::TRandState rand_state = std::random_device()(); - if (seed.defined()) { - tir::RandEngine(&rand_state).Seed(seed.value()); + tir::TRandState rand_state; + if (seed.defined() && (seed.value() != -1)) { + tir::RandEngine(&rand_state).Seed(seed.value()->value); + } else { + tir::RandEngine(&rand_state).Seed(std::random_device()()); } return mutator->Apply(task, trace, &rand_state); } From 5974a1ce14e8cfaae8274c522f717cc0eae995cf Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 17:24:18 -0700 Subject: [PATCH 18/23] Fix minor bug. --- src/meta_schedule/autotune.cc | 2 +- src/meta_schedule/search.cc | 10 +++++----- src/meta_schedule/space/postproc.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 6 +++--- src/meta_schedule/strategy/mutator.cc | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index de7b0edd23..fc5b156d4e 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -24,7 +24,7 @@ namespace tvm { namespace meta_schedule { void TuneContextNode::Init(Optional seed) { - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&this->rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&this->rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index e1adc04dd5..de966708d3 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -59,7 +59,7 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -111,7 +111,7 @@ struct Internal { static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -128,7 +128,7 @@ struct Internal { static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -146,7 +146,7 @@ struct Internal { static Array 
SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -166,7 +166,7 @@ struct Internal { SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index b957b8ce90..2c0634ff66 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -1119,7 +1119,7 @@ struct Internal { */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index 757fdcaa92..6f71b7c117 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -780,7 +780,7 @@ struct Internal { SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -800,7 +800,7 @@ struct Internal { static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -820,7 +820,7 @@ struct Internal { SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index 1adf5df9e3..4a794cfcb8 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -482,7 +482,7 @@ struct Internal { static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && (seed.value() != -1)) { + if (seed.defined() && seed.value()->value > 0) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); From 89fba270781cfc2aa05b7e7b8b92843a9395213d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 17:29:38 -0700 Subject: [PATCH 19/23] Monir fix. 
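For reviewers skimming the hunks below: besides the RPC key updates, this patch only reformats the build-under-tuned-records call in test_resnet_end_to_end_cuda.py. In sketch form, that call follows the pattern below (the variables `log`, `SPACE`, `mod`, `target`, and `params` come from the surrounding test; the imports are added here only to make the sketch self-contained):

    import tvm
    from tvm import relay
    from tvm import meta_schedule as ms

    with ms.ApplyHistoryBest(log, SPACE):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.with_tir_schedule": True, "relay.backend.use_meta_schedule": True},
        ):
            lib = relay.build_module.build(mod, target, params=params)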
--- .../python/meta_schedule/test_integration_cuda.py | 2 +- .../meta_schedule/test_resnet_end_to_end.py | 2 +- .../meta_schedule/test_resnet_end_to_end_cuda.py | 15 ++++++++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/python/meta_schedule/test_integration_cuda.py b/tests/python/meta_schedule/test_integration_cuda.py index 361a7a52cd..526b1339f9 100644 --- a/tests/python/meta_schedule/test_integration_cuda.py +++ b/tests/python/meta_schedule/test_integration_cuda.py @@ -29,7 +29,7 @@ logging.basicConfig() logging.getLogger("meta_schedule").setLevel(logging.DEBUG) -RPC_KEY = "rtx-3070" +RPC_KEY = "rtx-2080ti" TARGET = tvm.target.Target("nvidia/geforce-rtx-2080-ti") TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( diff --git a/tests/python/meta_schedule/test_resnet_end_to_end.py b/tests/python/meta_schedule/test_resnet_end_to_end.py index 0b1b97b19e..c0ad78838c 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end.py @@ -96,7 +96,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "local" +RPC_KEY = "test" network = "resnet-50" batch_size = 1 layout = "NHWC" diff --git a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py index d9133036fe..6cc9c8fb77 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py @@ -96,7 +96,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "rtx-3070" +RPC_KEY = "rtx-2080ti" network = "resnet-50" batch_size = 1 layout = "NHWC" @@ -177,12 +177,14 @@ def test_end_to_end_resnet(log): measure_callbacks=[ ms.RecordToFile(), ] - ) + ), ) with ms.ApplyHistoryBest(log, SPACE): - with tvm.transform.PassContext(opt_level=3, config={"relay.with_tir_schedule": True, - "relay.backend.use_meta_schedule": True}): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.with_tir_schedule": True, "relay.backend.use_meta_schedule": True}, + ): lib = relay.build_module.build(mod, target, params=params) def run_module(lib): @@ -195,7 +197,10 @@ def run_module(lib): print("Evaluate inference time cost...") ftimer = module.module.time_evaluator("run", ctx, repeat=3, min_repeat_ms=500) prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) + print( + "Mean inference time (std dev): %.2f ms (%.2f ms)" + % (np.mean(prof_res), np.std(prof_res)) + ) module.run() return module.get_output(0) From d9f8c16acec88ca4a6ef5a5782ff03465ab295d2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 17:47:39 -0700 Subject: [PATCH 20/23] Make seed value consistent. 
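The guard now treats -1 (and an undefined seed) as "pick a fresh seed" everywhere, instead of the `> 0` check introduced two patches ago. A stand-alone sketch of the convention, for reference only (the helper name is hypothetical, not code from this series):

    import random

    def resolve_seed(seed=None):
        # None or -1 -> draw a fresh seed from the system RNG;
        # any other value is used verbatim so runs stay reproducible.
        if seed is None or seed == -1:
            return random.SystemRandom().randrange(1, 2**31)
        return seed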
--- src/meta_schedule/autotune.cc | 2 +- src/meta_schedule/search.cc | 10 +++++----- src/meta_schedule/space/postproc.cc | 2 +- src/meta_schedule/strategy/evolutionary.cc | 6 +++--- src/meta_schedule/strategy/mutator.cc | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/autotune.cc b/src/meta_schedule/autotune.cc index fc5b156d4e..3ad73fd99f 100644 --- a/src/meta_schedule/autotune.cc +++ b/src/meta_schedule/autotune.cc @@ -24,7 +24,7 @@ namespace tvm { namespace meta_schedule { void TuneContextNode::Init(Optional seed) { - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&this->rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&this->rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/search.cc b/src/meta_schedule/search.cc index de966708d3..e4b161049c 100644 --- a/src/meta_schedule/search.cc +++ b/src/meta_schedule/search.cc @@ -59,7 +59,7 @@ SearchTask::SearchTask(tir::PrimFunc workload, String task_name, Target target, TVM_DLL Optional AutoTune(SearchTask task, SearchSpace space, SearchStrategy strategy, ProgramMeasurer measurer, Optional seed, int verbose) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -111,7 +111,7 @@ struct Internal { static bool SearchSpacePostprocess(SearchSpace space, SearchTask task, Schedule sch, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -128,7 +128,7 @@ struct Internal { static Schedule SearchSpaceSampleSchedule(SearchSpace space, SearchTask task, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -146,7 +146,7 @@ struct Internal { static Array SearchSpaceGetSupport(SearchSpace space, SearchTask task, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -166,7 +166,7 @@ struct Internal { SearchSpace space, ProgramMeasurer measurer, Optional seed, int verbose) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/space/postproc.cc b/src/meta_schedule/space/postproc.cc index 2c0634ff66..c0e1c599f6 100644 --- a/src/meta_schedule/space/postproc.cc +++ b/src/meta_schedule/space/postproc.cc @@ -1119,7 +1119,7 @@ struct Internal { */ static bool Apply(Postproc self, SearchTask task, Schedule sch, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { 
tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/strategy/evolutionary.cc b/src/meta_schedule/strategy/evolutionary.cc index 6f71b7c117..9b1c9c6357 100644 --- a/src/meta_schedule/strategy/evolutionary.cc +++ b/src/meta_schedule/strategy/evolutionary.cc @@ -780,7 +780,7 @@ struct Internal { SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -800,7 +800,7 @@ struct Internal { static Array EvolveWithCostModel(Evolutionary self, Array inits, SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); @@ -820,7 +820,7 @@ struct Internal { SearchTask task, SearchSpace space, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); diff --git a/src/meta_schedule/strategy/mutator.cc b/src/meta_schedule/strategy/mutator.cc index 4a794cfcb8..3ed1b90e2d 100644 --- a/src/meta_schedule/strategy/mutator.cc +++ b/src/meta_schedule/strategy/mutator.cc @@ -482,7 +482,7 @@ struct Internal { static Optional Apply(Mutator mutator, SearchTask task, Trace trace, Optional seed) { tir::TRandState rand_state; - if (seed.defined() && seed.value()->value > 0) { + if (seed.defined() && seed.value()->value != -1) { tir::RandEngine(&rand_state).Seed(seed.value()->value); } else { tir::RandEngine(&rand_state).Seed(std::random_device()()); From c38e52cfc040e2ec5c9d33c02a304345ba4e96e2 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 18:32:54 -0700 Subject: [PATCH 21/23] Revoke test changes. 
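This restores the hardware RPC keys and targets that earlier patches in the series had switched to local llvm/test values. Both "nvidia/geforce-rtx-3070" and "raspberry-pi/4b-64" are target tags recognized by tvm.target.Target, so the tests construct them by name alone; a minimal illustration (the printed kind is indicative only, not taken from this patch):

    import tvm

    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    print(target.kind.name)  # expected to be "cuda" for this tag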
--- tests/python/meta_schedule/test_integration_cuda.py | 4 ++-- tests/python/meta_schedule/test_resnet_end_to_end.py | 6 +++--- tests/python/meta_schedule/test_resnet_end_to_end_cuda.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/meta_schedule/test_integration_cuda.py b/tests/python/meta_schedule/test_integration_cuda.py index 526b1339f9..67c5056f4e 100644 --- a/tests/python/meta_schedule/test_integration_cuda.py +++ b/tests/python/meta_schedule/test_integration_cuda.py @@ -29,8 +29,8 @@ logging.basicConfig() logging.getLogger("meta_schedule").setLevel(logging.DEBUG) -RPC_KEY = "rtx-2080ti" -TARGET = tvm.target.Target("nvidia/geforce-rtx-2080-ti") +RPC_KEY = "rtx-3070" +TARGET = tvm.target.Target("nvidia/geforce-rtx-3070") TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( stages=[ diff --git a/tests/python/meta_schedule/test_resnet_end_to_end.py b/tests/python/meta_schedule/test_resnet_end_to_end.py index c0ad78838c..c0ada49b01 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end.py @@ -96,13 +96,13 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "test" +RPC_KEY = "raspi4b-aarch64" network = "resnet-50" batch_size = 1 layout = "NHWC" -target = tvm.target.Target("llvm") +target = tvm.target.Target("raspberry-pi/4b-64") dtype = "float32" -TARGET_HOST = tvm.target.Target("llvm") +TARGET_HOST = tvm.target.Target("raspberry-pi/4b-64") SPACE = ms.space.PostOrderApply( stages=[ ms.rule.inline_pure_spatial(strict_mode=True), diff --git a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py index 6cc9c8fb77..b8b94ddbc2 100644 --- a/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py +++ b/tests/python/meta_schedule/test_resnet_end_to_end_cuda.py @@ -96,11 +96,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): return mod, params, input_shape, output_shape -RPC_KEY = "rtx-2080ti" +RPC_KEY = "rtx-3080" network = "resnet-50" batch_size = 1 layout = "NHWC" -target = tvm.target.Target("nvidia/geforce-rtx-2080-ti") +target = tvm.target.Target("nvidia/geforce-rtx-3080") dtype = "float32" TARGET_HOST = tvm.target.Target("llvm") SPACE = ms.space.PostOrderApply( From b2205662092f59a0310f8c535e71ced7efd02d86 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 18:35:22 -0700 Subject: [PATCH 22/23] Revoke more tests. 
--- .../test_meta_schedule_layout_rewrite_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py index eee39ec024..0b56143fab 100644 --- a/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py +++ b/tests/python/meta_schedule/test_meta_schedule_layout_rewrite_network.py @@ -99,9 +99,9 @@ def get_relay_batchmm(batch=4, m=128, n=128, k=128): return mod, data, weight -RPC_KEY = "test" -TARGET = tvm.target.Target("llvm") -TARGET_HOST = tvm.target.Target("llvm") +RPC_KEY = "raspi4b-aarch64" +TARGET = tvm.target.Target("raspberry-pi/4b-64") +TARGET_HOST = tvm.target.Target("raspberry-pi/4b-64") SPACE = ms.space.PostOrderApply( stages=[ ms.rule.inline_pure_spatial(strict_mode=True), From ca6aeefb2e514c0e56f6118fa450929edef7c8cd Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 11 Aug 2021 18:39:25 -0700 Subject: [PATCH 23/23] Move prime table. --- src/tir/schedule/primitive.h | 100 ------------------------- src/tir/schedule/primitive/sampling.cc | 100 +++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 99db58c5f8..fff72d5859 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -29,106 +29,6 @@ namespace tvm { namespace tir { -struct PrimeTable { - /*! \brief The table contains prime numbers in [2, kMaxPrime) */ - static constexpr const int kMaxPrime = 65536; - /*! \brief The exact number of prime numbers in the table */ - static constexpr const int kNumPrimes = 6542; - /*! - * \brief For each number in [2, kMaxPrime), the index of its min factor. - * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. - */ - int min_factor_idx[kMaxPrime]; - /*! \brief The prime numbers in [2, kMaxPrime) */ - std::vector primes; - /*! - * \brief The power of each prime number. - * pow_table[i, j] stores the result of pow(prime[i], j + 1) - */ - std::vector> pow_tab; - - /*! \brief Get a global instance of the prime table */ - static const PrimeTable* Global() { - static const PrimeTable table; - return &table; - } - - /*! \brief Constructor, pre-computes all info in the prime table */ - PrimeTable() { - constexpr const int64_t int_max = std::numeric_limits::max(); - // Euler's sieve: prime number in linear time - for (int i = 0; i < kMaxPrime; ++i) { - min_factor_idx[i] = -1; - } - primes.reserve(kNumPrimes); - for (int x = 2; x < kMaxPrime; ++x) { - if (min_factor_idx[x] == -1) { - min_factor_idx[x] = primes.size(); - primes.push_back(x); - } - for (size_t i = 0; i < primes.size(); ++i) { - int factor = primes[i]; - int y = x * factor; - if (y >= kMaxPrime) { - break; - } - min_factor_idx[y] = i; - if (x % factor == 0) { - break; - } - } - } - ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); - // Calculate the power table for each prime number - pow_tab.reserve(primes.size()); - for (int prime : primes) { - std::vector tab; - tab.reserve(32); - for (int64_t pow = prime; pow <= int_max; pow *= prime) { - tab.push_back(pow); - } - tab.shrink_to_fit(); - pow_tab.emplace_back(std::move(tab)); - } - } - /*! 
- * \brief Factorize a number n, and return in a cryptic format - * \param n The number to be factorized - * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] - * For each pair (i, j), we define - * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) - * (primes[i], j) if i != -1 - * Then the factorization is - * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) - */ - std::vector> Factorize(int n) const { - std::vector> result; - result.reserve(16); - int i = 0, n_primes = primes.size(); - // Phase 1: n >= kMaxPrime - for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - if (j != 0) { - result.emplace_back(i, j); - } - } - // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number - if (n >= kMaxPrime) { - result.emplace_back(-1, n); - return result; - } - // Phase 2: n < kMaxPrime - for (int j; n > 1;) { - int i = min_factor_idx[n]; - for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { - } - result.emplace_back(i, j); - } - return result; - } -}; - /******** Schedule: Sampling ********/ /*! \brief Return a seed that can be used as a new random state. */ diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index ba601382f6..02a16008c9 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,6 +24,106 @@ namespace tvm { namespace tir { +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! \brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int factor = primes[i]; + int y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), int(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! 
+ * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int n) const { + std::vector> result; + result.reserve(16); + int i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int j; n > 1;) { + int i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + TRandState ForkSeed(TRandState* rand_state) { // In order for reproducibility, we computer the new seed using RNG's random state and a // different set of parameters. Note that both 32767 and 1999999973 are prime numbers.