Skip to content

Commit

Permalink
[MetaSchedule] Add Gradient Based Task Scheduler (apache#10366)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
2 people authored and pfk-beta committed Apr 11, 2022
1 parent 5436907 commit 4e9ea73
Show file tree
Hide file tree
Showing 34 changed files with 894 additions and 323 deletions.
12 changes: 6 additions & 6 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,21 @@ class SearchStrategy : public runtime::ObjectRef {
/*!
* \brief Constructor of replay trace search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for trace replaying.
* \param max_trials_per_task The total number of trials for trace replaying.
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);

/*!
* \brief Constructor of replay func search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for func replaying.
* \param max_trials_per_task The total number of trials for func replaying.
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);

/*!
* \brief Constructor of evolutionary search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for evolutionary search.
* \param max_trials_per_task The total number of trials for evolutionary search.
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measures samples in initial population.
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
Expand All @@ -276,7 +276,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int num_trials_total, //
int max_trials_per_task, //
int population_size, //
double init_measured_ratio, //
int init_min_unmeasured, //
Expand Down
94 changes: 52 additions & 42 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The maximum number of trials allowed. */
int max_trials;
/*! \brief The cost model of the scheduler. */
Optional<CostModel> cost_model;
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
v->Visit("max_trials", &max_trials);
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_trials_already", &num_trials_already);
}

/*! \brief Auto-tuning. */
Expand All @@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
virtual void InitializeTask(int task_id);

/*!
* \brief Set specific task to be stopped.
* \param task_id The task id to be stopped.
*/
virtual void SetTaskStopped(int task_id);

/*!
* \brief Check whether the task is running.
* \brief Touch the task and update its status
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
virtual bool IsTaskRunning(int task_id);
virtual void TouchTask(int task_id);

/*!
* \brief Wait until the task is finished.
* \param task_id The task id to be joined.
*/
virtual void JoinRunningTask(int task_id);
virtual Array<RunnerResult> JoinRunningTask(int task_id);

/*!
* \brief Fetch the next task id.
Expand All @@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
using FInitializeTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `SetTaskStopped` method.
* \param task_id The task id to be stopped.
*/
using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `IsTaskRunning` method.
* \brief The function type of `TouchTask` method.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
using FTouchTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `JoinRunningTask` method.
* \param task_id The task id to be joined.
*/
using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;

/*!
* \brief The function type of `NextTaskId` method.
Expand All @@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
FTune f_tune;
/*! \brief The packed function to the `InitializeTask` function. */
FInitializeTask f_initialize_task;
/*! \brief The packed function to the `SetTaskStopped` function. */
FSetTaskStopped f_set_task_stopped;
/*! \brief The packed function to the `IsTaskRunning` function. */
FIsTaskRunning f_is_task_running;
/*! \brief The packed function to the `TouchTask` function. */
FTouchTask f_touch_task;
/*! \brief The packed function to the `JoinRunningTask` function. */
FJoinRunningTask f_join_running_task;
/*! \brief The packed function to the `NextTaskId` function. */
Expand All @@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
void VisitAttrs(tvm::AttrVisitor* v) {
// `f_tune` is not visited
// `f_initialize_task` is not visited
// `f_set_task_stopped` is not visited
// `f_is_task_running` is not visited
// `f_touch_task` is not visited
// `f_join_running_task` is not visited
// `f_next_task_id` is not visited
}
Expand All @@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
}
}

void SetTaskStopped(int task_id) final {
if (f_set_task_stopped == nullptr) {
TaskSchedulerNode::SetTaskStopped(task_id);
} else {
f_set_task_stopped(task_id);
}
}

bool IsTaskRunning(int task_id) final {
if (f_is_task_running == nullptr) {
return TaskSchedulerNode::IsTaskRunning(task_id);
void TouchTask(int task_id) final {
if (f_touch_task == nullptr) {
return TaskSchedulerNode::TouchTask(task_id);
} else {
return f_is_task_running(task_id);
return f_touch_task(task_id);
}
}

void JoinRunningTask(int task_id) final {
Array<RunnerResult> JoinRunningTask(int task_id) final {
if (f_join_running_task == nullptr) {
return TaskSchedulerNode::JoinRunningTask(task_id);
} else {
Expand Down Expand Up @@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \return The task scheduler created.
Expand All @@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
* \param task_weights The weights of each task.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param alpha The parameter alpha to control gradient computation.
* \param window_size The parameter to control backward window size.
* \param seed The random seed.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,
Array<FloatImm> task_weights, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
double alpha, //
int window_size, //
support::LinearCongruentialEngine::TRandState seed);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_set_task_stopped The packed function of `SetTaskStopped`.
* \param f_is_task_running The packed function of `IsTaskRunning`.
* \param f_touch_task The packed function of `TouchTask`.
* \param f_join_running_task The packed function of `JoinRunningTask`.
* \param f_next_task_id The packed function of `NextTaskId`.
* \return The task scheduler created.
Expand All @@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef {
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FTouchTask f_touch_task, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object {
/*! \brief The task scheduler that owns the tune context */
const TaskSchedulerNode* task_scheduler;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
bool is_terminated;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
Expand All @@ -81,7 +81,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
v->Visit("is_terminated", &is_terminated);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ class LinearCongruentialEngine {
* \brief Change the start random state of RNG with the seed of a new random state value.
* \param rand_state The random state given in result_type.
*/
void Seed(TRandState rand_state = 1) {
ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
rand_state %= modulus; // Make sure the seed is within the range of modulus.
if (rand_state == 0)
rand_state = 1; // Avoid getting all 0 given the current parameter set.
else if (rand_state < 0)
rand_state += modulus; // Make sure the rand state is non-negative.
ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null.
*rand_state_ptr_ = rand_state; // Change pointed random state to given random state value.
void Seed(TRandState rand_state) {
if (rand_state == -1) {
rand_state = DeviceRandom();
} else if (rand_state == 0) {
rand_state = 1;
}
ICHECK(rand_state >= 0) << "The random state should be nonnegative";
ICHECK(rand_state_ptr_ != nullptr);
*rand_state_ptr_ = rand_state % modulus;
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ScheduleNode : public runtime::Object {
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
/*! \brief Fork the random state */
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;

Expand Down
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class EvolutionarySearch(SearchStrategy):
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
max_trials_per_task : int
Total number of trials.
population_size : int
The initial population of traces from measured samples and randomly generated samples.
Expand All @@ -53,7 +53,7 @@ class EvolutionarySearch(SearchStrategy):
"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
population_size: int
init_measured_ratio: int
init_min_unmeasured: int
Expand All @@ -66,7 +66,7 @@ def __init__(
self,
*,
num_trials_per_iter: int,
num_trials_total: int,
max_trials_per_task: int,
population_size: int,
init_measured_ratio: float,
init_min_unmeasured: int,
Expand All @@ -79,7 +79,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
max_trials_per_task,
population_size,
init_measured_ratio,
init_min_unmeasured,
Expand All @@ -94,7 +94,8 @@ class EvolutionarySearchConfig(NamedTuple):
"""Configuration for EvolutionarySearch"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
max_trials_global: int
population_size: int = 2048
init_measured_ratio: float = 0.2
init_min_unmeasured: int = 50
Expand All @@ -106,7 +107,7 @@ class EvolutionarySearchConfig(NamedTuple):
def create_strategy(self) -> EvolutionarySearch:
return EvolutionarySearch(
num_trials_per_iter=self.num_trials_per_iter,
num_trials_total=self.num_trials_total,
max_trials_per_task=self.max_trials_per_task,
population_size=self.population_size,
init_measured_ratio=self.init_measured_ratio,
init_min_unmeasured=self.init_min_unmeasured,
Expand Down
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/search_strategy/replay_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,32 @@ class ReplayFunc(SearchStrategy):
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
max_trials_per_task : int
Total number of trials.
"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int

def __init__(
self,
num_trials_per_iter: int,
num_trials_total: int,
max_trials_per_task: int,
):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
max_trials_per_task,
)


class ReplayFuncConfig(NamedTuple):
"""Configuration for ReplayFunc"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
max_trials_global: int

def create_strategy(self) -> ReplayFunc:
return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task)
Loading

0 comments on commit 4e9ea73

Please sign in to comment.