Skip to content

Commit

Permalink
[M3a] Sampling Primitives & Random Number Generator (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh authored Aug 12, 2021
1 parent 5cb602f commit 14632f0
Show file tree
Hide file tree
Showing 32 changed files with 1,183 additions and 1,020 deletions.
121 changes: 121 additions & 0 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file random_engine.h
* \brief Random number generator, for Sampling functions.
*/

#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
#define TVM_SUPPORT_RANDOM_ENGINE_H_

#include <tvm/runtime/logging.h>

#include <cstdint> // for uint64_t

namespace tvm {
namespace support {

/*!
 * \brief Linear congruential random engine using the parameters of std::minstd_rand
 *  (multiplier 48271, increment 0, modulus 2^31 - 1). It is written with fixed-width integer
 *  types and is designed to be platform-independent.
 * \note The engine does not own its random state. It reads and updates a TRandState through the
 *  pointer given at construction, so the pointed-to state must stay alive for as long as the
 *  engine is used.
 * \note The class provides result_type, min(), max() and operator(), which is the interface
 *  STL random number distributions require of a generator. Other member functions of
 *  std::linear_congruential_engine (e.g. discard, comparison, stream operators) are omitted for
 *  simplicity. See https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
 */
class LinearCongruentialEngine {
 public:
  /*!
   * \brief The unsigned integer type returned by operator(), min() and max().
   * \note The type name is not in Google style because it is used in STL's distribution interface.
   */
  using result_type = uint64_t;
  /*!
   * \brief The type used to store the random state. It is a signed 64-bit integer, wide enough
   *  that `state * multiplier + increment` does not overflow for any state in [0, modulus - 1].
   */
  using TRandState = int64_t;

  /*! \brief The multiplier */
  static constexpr TRandState multiplier = 48271;

  /*! \brief The increment */
  static constexpr TRandState increment = 0;

  /*! \brief The modulus */
  static constexpr TRandState modulus = 2147483647;

  /*!
   * \brief The lower bound reported to STL distributions.
   * \note The function name is uncapitalized because it is used in STL's distribution interface.
   * \note NOTE(review): std::minstd_rand::min() is 1 because its increment is 0; this engine
   *  reports 0. Kept as-is so that existing seeded sampling results do not change.
   */
  static constexpr result_type min() { return 0; }

  /*!
   * \brief The upper bound reported to STL distributions, i.e. modulus - 1.
   * \note The function name is uncapitalized because it is used in STL's distribution interface.
   */
  static constexpr result_type max() { return static_cast<result_type>(modulus - 1); }

  /*!
   * \brief Advance the pointed-to random state by one step and return the new state.
   *  The new state is computed as
   *  new_random_state = (current_random_state * multiplier + increment) % modulus.
   * \return The updated random state, converted to result_type.
   * \note No checks are performed here for efficiency. The implementation assumes:
   *  1. The multiplication and addition won't overflow.
   *  2. The random state pointer held by the engine is not nullptr.
   *  3. The pointed-to random state is in the range of [0, modulus - 1].
   */
  result_type operator()() {
    TRandState& state = *rand_state_ptr_;
    // With state < 2^31 and multiplier < 2^16 the product stays below 2^47, well inside int64_t.
    state = (state * multiplier + increment) % modulus;
    return static_cast<result_type>(state);
  }

  /*!
   * \brief Overwrite the pointed-to random state with a normalized seed.
   * \param rand_state The seed, given as a TRandState. It is reduced modulo `modulus`; a result
   *  of 0 is replaced by 1 and a negative result is shifted up by `modulus`, so the stored state
   *  always lies in [1, modulus - 1].
   */
  void Seed(TRandState rand_state = 1) {
    // Validate the destination before doing any work on the seed.
    ICHECK(rand_state_ptr_ != nullptr);
    // `%` keeps the sign of the dividend, so the remainder is in (-modulus, modulus).
    rand_state %= modulus;
    if (rand_state == 0) {
      // With increment == 0 a zero state would stay zero forever.
      rand_state = 1;
    } else if (rand_state < 0) {
      // Map a negative remainder into the valid non-negative range.
      rand_state += modulus;
    }
    *rand_state_ptr_ = rand_state;
  }

  /*!
   * \brief Construct an engine that operates on an externally stored random state.
   * \param rand_state_ptr Pointer to the random state (a TRandState) the engine reads and updates.
   * \note Neither the pointer nor the pointed-to value is validated here. The caller must either
   *  pass a non-null pointer to a state in [0, modulus - 1], or call Seed right after
   *  construction and before any other use.
   */
  explicit LinearCongruentialEngine(TRandState* rand_state_ptr)
      : rand_state_ptr_(rand_state_ptr) {}

 private:
  /*! \brief Non-owning pointer to the random state this engine advances. */
  TRandState* rand_state_ptr_;
};

} // namespace support
} // namespace tvm

#endif // TVM_SUPPORT_RANDOM_ENGINE_H_
14 changes: 9 additions & 5 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>

namespace tvm {
namespace tir {

using TRandState = support::LinearCongruentialEngine::TRandState;
using RandEngine = support::LinearCongruentialEngine;

/*! \brief The level of detailed error message rendering */
enum class ScheduleErrorRenderLevel : int32_t {
/*! \brief Render a detailed error message */
Expand Down Expand Up @@ -113,12 +117,12 @@ class ScheduleNode : public runtime::Object {
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
virtual Schedule Copy(int64_t seed = -1) const = 0;
virtual Schedule Copy(tir::TRandState seed = -1) const = 0;
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(int64_t seed = -1) = 0;
virtual void Seed(tir::TRandState seed = -1) = 0;
/*! \brief Fork the random state */
virtual int64_t ForkSeed() = 0;

Expand Down Expand Up @@ -502,11 +506,11 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Concrete(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Meta(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Meta(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Traced(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Traced(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def make_error_msg() -> str:
""" Get the error message from traceback. """
"""Get the error message from traceback."""
error_msg = str(traceback.format_exc())
if len(error_msg) > MAX_ERROR_MSG_LEN:
error_msg = (
Expand Down Expand Up @@ -408,7 +408,7 @@ def local_builder_worker(
timeout: int,
verbose: int,
) -> BuildResult.TYPE:
""" Local worker for ProgramBuilder """
"""Local worker for ProgramBuilder"""
# deal with build_func
build_func = {
"tar": build_func_tar.tar, # export to tar
Expand Down Expand Up @@ -603,7 +603,7 @@ def rpc_runner_worker(
f_create_args: Callable[[Device], List[NDArray]],
verbose: int,
) -> MeasureResult.TYPE:
""" RPC worker for ProgramRunner """
"""RPC worker for ProgramRunner"""
measure_input = measure_inputs[index]
build_result = build_results[index]

Expand Down Expand Up @@ -653,13 +653,13 @@ def timed_func():
else:
rpc_eval_repeat = 1
if f_create_args is not None:
args_set = [f_create_args(ctx) for _ in range(rpc_eval_repeat)]
args_set = [f_create_args(dev) for _ in range(rpc_eval_repeat)]
else:
args_set = [
realize_arguments(remote, ctx, measure_input.sch.mod["main"])
realize_arguments(remote, dev, measure_input.sch.mod["main"])
for _ in range(rpc_eval_repeat)
]
ctx.sync()
dev.sync()
costs = sum([time_f(*args).results for args in args_set], ())
# clean up remote files
remote.remove(build_result.filename)
Expand Down
8 changes: 5 additions & 3 deletions src/meta_schedule/autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ namespace tvm {
namespace meta_schedule {

void TuneContextNode::Init(Optional<Integer> seed) {
if (seed.defined()) {
this->sampler.Seed(seed.value()->value);
if (seed.defined() && seed.value()->value != -1) {
tir::RandEngine(&this->rand_state).Seed(seed.value()->value);
} else {
tir::RandEngine(&this->rand_state).Seed(std::random_device()());
}
if (task.defined()) {
task.value()->Init(this);
Expand Down Expand Up @@ -59,7 +61,7 @@ void TuneContextNode::Init(Optional<Integer> seed) {
bool TuneContextNode::Postprocess(const Schedule& sch) {
sch->EnterPostproc();
for (const Postproc& postproc : postprocs) {
if (!postproc->Apply(task.value(), sch, &sampler)) {
if (!postproc->Apply(task.value(), sch, &rand_state)) {
return false;
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/meta_schedule/autotune.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class TuneContextNode : public runtime::Object {
Array<Postproc> postprocs;
Array<MeasureCallback> measure_callbacks;
int num_threads;
Sampler sampler;

tir::TRandState rand_state;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("task", &task);
Expand All @@ -56,7 +57,6 @@ class TuneContextNode : public runtime::Object {
v->Visit("postprocs", &postprocs);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_threads", &num_threads);
// `sampler` is not visited
}

void Init(Optional<Integer> seed = NullOpt);
Expand Down Expand Up @@ -93,7 +93,6 @@ class TuneContext : public runtime::ObjectRef {
n->postprocs = postprocs;
n->measure_callbacks = measure_callbacks;
n->num_threads = num_threads;
// `n->sampler` is not initialized
data_ = std::move(n);
(*this)->Init(seed);
}
Expand Down
17 changes: 7 additions & 10 deletions src/meta_schedule/cost_model/rand_cost_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ namespace meta_schedule {
/*! \brief The cost model returning random value for all predictions */
class RandCostModelNode : public CostModelNode {
public:
/*! \brief A sampler for generating random numbers */
Sampler sampler;
/*! \brief A random state for sampling functions to generate random numbers */
tir::TRandState rand_state;

void VisitAttrs(tvm::AttrVisitor* v) {
// sampler is not visited
}
void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Update the cost model according to new measurement results (training data).
Expand All @@ -48,7 +46,7 @@ class RandCostModelNode : public CostModelNode {
* \return The predicted scores for all states
*/
std::vector<double> Predict(const SearchTask& task, const Array<Schedule>& states) override {
return sampler.SampleUniform(states.size(), 0.0, 1.0);
return tir::SampleUniform(&rand_state, states.size(), 0.0, 1.0);
}

static constexpr const char* _type_key = "meta_schedule.RandCostModel";
Expand All @@ -61,11 +59,10 @@ class RandCostModelNode : public CostModelNode {
*/
class RandCostModel : public CostModel {
public:
RandCostModel() { data_ = make_object<RandCostModelNode>(); }

explicit RandCostModel(int seed) {
explicit RandCostModel(int seed = -1) {
ObjectPtr<RandCostModelNode> n = make_object<RandCostModelNode>();
n->sampler.Seed(seed);
if (seed == -1) seed = std::random_device()();
tir::RandEngine(&n->rand_state).Seed(seed);
data_ = std::move(n);
}

Expand Down
3 changes: 0 additions & 3 deletions src/meta_schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

#include "../tir/schedule/sampler.h"

namespace tvm {
namespace meta_schedule {

using ScheduleNode = tir::TraceNode;
using Schedule = tir::Schedule;
using Sampler = tir::Sampler;
using BlockRV = tir::BlockRV;
using BlockRVNode = tir::BlockRVNode;
using LoopRV = tir::LoopRV;
Expand Down
Loading

0 comments on commit 14632f0

Please sign in to comment.