Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[M3a] Sampling Primitives & Random Number Generator #421

Merged
merged 23 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file random_engine.h
* \brief Random number generator, for Sampling functions.
*/

#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
#define TVM_SUPPORT_RANDOM_ENGINE_H_

#include <tvm/runtime/logging.h>

#include <cstdint> // for uint64_t

namespace tvm {
namespace support {

/*!
* \brief This linear congruential engine is a drop-in replacement for std::minstd_rand. It strictly
* corresponds to std::minstd_rand and is designed to be platform-independent.
* \note Our linear congruential engine is a complete implementation of
* std::uniform_random_bit_generator so it can be used as generator for any STL random number
* distribution. However, parts of std::linear_congruential_engine's member functions are not
* included for simplification. For full member functions of std::minstd_rand, please check out the
* following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
*/
class LinearCongruentialEngine {
public:
/*!
* \brief The result type is defined as int64_t here to avoid overflow.
* \note The type name is not in Google style because it is used in STL's distribution inferface.
*/
using result_type = uint64_t;
using TRandState = int64_t;

/*! \brief The multiplier */
static constexpr TRandState multiplier = 48271;

/*! \brief The increment */
static constexpr TRandState increment = 0;

/*! \brief The modulus */
static constexpr TRandState modulus = 2147483647;

/*!
* \brief The minimum possible value of random state here.
* \note The function name is uncapilized because it is used in STL's distribution inferface.
*/
static constexpr result_type min() { return 0; }

/*!
* \brief The maximum possible value of random state here.
* \note The function name is uncapilized because it is used in STL's distribution inferface.
*/
static constexpr result_type max() { return modulus - 1; }

/*!
* \brief Operator to move the random state to the next and return the new random state. According
* to definition of linear congruential engine, the new random state value is computed as
* new_random_state = (current_random_state * multiplier + increment) % modulus.
* \return The next current random state value in the type of result_type.
* \note In order for better efficiency, the implementation here has a few assumptions:
* 1. The multiplication and addition won't overflow.
* 2. The given random state pointer `rand_state_ptr` is not nullptr.
* 3. The given random state `*(rand_state_ptr)` is in the range of [0, modulus - 1].
*/
result_type operator()() {
(*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus;
return *rand_state_ptr_;
}

/*!
* \brief Change the start random state of RNG with the seed of a new random state value.
* \param rand_state The random state given in result_type.
*/
void Seed(TRandState rand_state = 1) {
rand_state %= modulus; // Make sure the seed is within the range of modulus.
if (rand_state == 0)
rand_state = 1; // Avoid getting all 0 given the current parameter set.
else if (rand_state < 0)
rand_state += modulus; // Make sure the rand state is non-negative.
ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null.
*rand_state_ptr_ = rand_state; // Change pointed random state to given random state value.
}

/*!
* \brief Construct a random number generator with a random state pointer.
* \param rand_state_ptr The random state pointer given in result_type*.
* \note The random state is not checked for whether it's nullptr and whether it's in the range of
* [0, modulus-1]. We assume the given random state is valid or the Seed function would be
* called right after the constructor before any usage.
*/
explicit LinearCongruentialEngine(TRandState* rand_state_ptr) {
rand_state_ptr_ = rand_state_ptr;
}

private:
TRandState* rand_state_ptr_;
};

} // namespace support
} // namespace tvm

#endif // TVM_SUPPORT_RANDOM_ENGINE_H_
14 changes: 9 additions & 5 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>

namespace tvm {
namespace tir {

using TRandState = support::LinearCongruentialEngine::TRandState;
using RandEngine = support::LinearCongruentialEngine;

/*! \brief The level of detailed error message rendering */
enum class ScheduleErrorRenderLevel : int32_t {
/*! \brief Render a detailed error message */
Expand Down Expand Up @@ -113,12 +117,12 @@ class ScheduleNode : public runtime::Object {
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* reconstructed
*/
virtual Schedule Copy(int64_t seed = -1) const = 0;
virtual Schedule Copy(tir::TRandState seed = -1) const = 0;
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(int64_t seed = -1) = 0;
virtual void Seed(tir::TRandState seed = -1) = 0;
/*! \brief Fork the random state */
virtual int64_t ForkSeed() = 0;

Expand Down Expand Up @@ -502,11 +506,11 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Concrete(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Meta(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Meta(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Traced(IRModule mod, int64_t seed, int debug_mode,
TVM_DLL static Schedule Traced(IRModule mod, tir::TRandState seed, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def make_error_msg() -> str:
""" Get the error message from traceback. """
"""Get the error message from traceback."""
error_msg = str(traceback.format_exc())
if len(error_msg) > MAX_ERROR_MSG_LEN:
error_msg = (
Expand Down Expand Up @@ -408,7 +408,7 @@ def local_builder_worker(
timeout: int,
verbose: int,
) -> BuildResult.TYPE:
""" Local worker for ProgramBuilder """
"""Local worker for ProgramBuilder"""
# deal with build_func
build_func = {
"tar": build_func_tar.tar, # export to tar
Expand Down Expand Up @@ -603,7 +603,7 @@ def rpc_runner_worker(
f_create_args: Callable[[Device], List[NDArray]],
verbose: int,
) -> MeasureResult.TYPE:
""" RPC worker for ProgramRunner """
"""RPC worker for ProgramRunner"""
measure_input = measure_inputs[index]
build_result = build_results[index]

Expand Down Expand Up @@ -653,13 +653,13 @@ def timed_func():
else:
rpc_eval_repeat = 1
if f_create_args is not None:
args_set = [f_create_args(ctx) for _ in range(rpc_eval_repeat)]
args_set = [f_create_args(dev) for _ in range(rpc_eval_repeat)]
else:
args_set = [
realize_arguments(remote, ctx, measure_input.sch.mod["main"])
realize_arguments(remote, dev, measure_input.sch.mod["main"])
for _ in range(rpc_eval_repeat)
]
ctx.sync()
dev.sync()
costs = sum([time_f(*args).results for args in args_set], ())
# clean up remote files
remote.remove(build_result.filename)
Expand Down
8 changes: 5 additions & 3 deletions src/meta_schedule/autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ namespace tvm {
namespace meta_schedule {

void TuneContextNode::Init(Optional<Integer> seed) {
if (seed.defined()) {
this->sampler.Seed(seed.value()->value);
if (seed.defined() && seed.value()->value != -1) {
tir::RandEngine(&this->rand_state).Seed(seed.value()->value);
} else {
tir::RandEngine(&this->rand_state).Seed(std::random_device()());
}
if (task.defined()) {
task.value()->Init(this);
Expand Down Expand Up @@ -59,7 +61,7 @@ void TuneContextNode::Init(Optional<Integer> seed) {
bool TuneContextNode::Postprocess(const Schedule& sch) {
sch->EnterPostproc();
for (const Postproc& postproc : postprocs) {
if (!postproc->Apply(task.value(), sch, &sampler)) {
if (!postproc->Apply(task.value(), sch, &rand_state)) {
return false;
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/meta_schedule/autotune.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class TuneContextNode : public runtime::Object {
Array<Postproc> postprocs;
Array<MeasureCallback> measure_callbacks;
int num_threads;
Sampler sampler;

tir::TRandState rand_state;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("task", &task);
Expand All @@ -56,7 +57,6 @@ class TuneContextNode : public runtime::Object {
v->Visit("postprocs", &postprocs);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_threads", &num_threads);
// `sampler` is not visited
}

void Init(Optional<Integer> seed = NullOpt);
Expand Down Expand Up @@ -93,7 +93,6 @@ class TuneContext : public runtime::ObjectRef {
n->postprocs = postprocs;
n->measure_callbacks = measure_callbacks;
n->num_threads = num_threads;
// `n->sampler` is not initialized
data_ = std::move(n);
(*this)->Init(seed);
}
Expand Down
17 changes: 7 additions & 10 deletions src/meta_schedule/cost_model/rand_cost_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ namespace meta_schedule {
/*! \brief The cost model returning random value for all predictions */
class RandCostModelNode : public CostModelNode {
public:
/*! \brief A sampler for generating random numbers */
Sampler sampler;
/*! \brief A random state for sampling functions to generate random numbers */
tir::TRandState rand_state;

void VisitAttrs(tvm::AttrVisitor* v) {
// sampler is not visited
}
void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Update the cost model according to new measurement results (training data).
Expand All @@ -48,7 +46,7 @@ class RandCostModelNode : public CostModelNode {
* \return The predicted scores for all states
*/
std::vector<double> Predict(const SearchTask& task, const Array<Schedule>& states) override {
return sampler.SampleUniform(states.size(), 0.0, 1.0);
return tir::SampleUniform(&rand_state, states.size(), 0.0, 1.0);
}

static constexpr const char* _type_key = "meta_schedule.RandCostModel";
Expand All @@ -61,11 +59,10 @@ class RandCostModelNode : public CostModelNode {
*/
class RandCostModel : public CostModel {
public:
RandCostModel() { data_ = make_object<RandCostModelNode>(); }

explicit RandCostModel(int seed) {
explicit RandCostModel(int seed = -1) {
ObjectPtr<RandCostModelNode> n = make_object<RandCostModelNode>();
n->sampler.Seed(seed);
if (seed == -1) seed = std::random_device()();
tir::RandEngine(&n->rand_state).Seed(seed);
data_ = std::move(n);
}

Expand Down
3 changes: 0 additions & 3 deletions src/meta_schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

#include "../tir/schedule/sampler.h"

namespace tvm {
namespace meta_schedule {

using ScheduleNode = tir::TraceNode;
using Schedule = tir::Schedule;
using Sampler = tir::Sampler;
using BlockRV = tir::BlockRV;
using BlockRVNode = tir::BlockRVNode;
using LoopRV = tir::LoopRV;
Expand Down
Loading