Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Add Gradient Based Task Scheduler #10366

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,21 @@ class SearchStrategy : public runtime::ObjectRef {
/*!
* \brief Constructor of replay trace search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for trace replaying.
* \param max_trials_per_task The maximum number of trials per task for trace replaying.
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);

/*!
* \brief Constructor of replay func search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for func replaying.
* \param max_trials_per_task The maximum number of trials per task for func replaying.
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);

/*!
* \brief Constructor of evolutionary search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for evolutionary search.
* \param max_trials_per_task The maximum number of trials per task for evolutionary search.
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measures samples in initial population.
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
Expand All @@ -276,7 +276,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int num_trials_total, //
int max_trials_per_task, //
int population_size, //
double init_measured_ratio, //
int init_min_unmeasured, //
Expand Down
94 changes: 52 additions & 42 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The maximum number of trials allowed. */
int max_trials;
/*! \brief The cost model of the scheduler. */
Optional<CostModel> cost_model;
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
v->Visit("max_trials", &max_trials);
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_trials_already", &num_trials_already);
}

/*! \brief Auto-tuning. */
Expand All @@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object {
virtual void InitializeTask(int task_id);

/*!
* \brief Set specific task to be stopped.
* \param task_id The task id to be stopped.
*/
virtual void SetTaskStopped(int task_id);

/*!
* \brief Check whether the task is running.
* \brief Touch the task and update its status
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
virtual bool IsTaskRunning(int task_id);
virtual void TouchTask(int task_id);

/*!
* \brief Wait until the task is finished.
* \param task_id The task id to be joined.
*/
virtual void JoinRunningTask(int task_id);
virtual Array<RunnerResult> JoinRunningTask(int task_id);

/*!
* \brief Fetch the next task id.
Expand All @@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
using FInitializeTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `SetTaskStopped` method.
* \param task_id The task id to be stopped.
*/
using FSetTaskStopped = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `IsTaskRunning` method.
* \brief The function type of `TouchTask` method.
* \param task_id The task id to be checked.
* \return Whether the task is running.
*/
using FIsTaskRunning = runtime::TypedPackedFunc<bool(int)>;
using FTouchTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `JoinRunningTask` method.
* \param task_id The task id to be joined.
*/
using FJoinRunningTask = runtime::TypedPackedFunc<void(int)>;
using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;

/*!
* \brief The function type of `NextTaskId` method.
Expand All @@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
FTune f_tune;
/*! \brief The packed function to the `InitializeTask` function. */
FInitializeTask f_initialize_task;
/*! \brief The packed function to the `SetTaskStopped` function. */
FSetTaskStopped f_set_task_stopped;
/*! \brief The packed function to the `IsTaskRunning` function. */
FIsTaskRunning f_is_task_running;
/*! \brief The packed function to the `TouchTask` function. */
FTouchTask f_touch_task;
/*! \brief The packed function to the `JoinRunningTask` function. */
FJoinRunningTask f_join_running_task;
/*! \brief The packed function to the `NextTaskId` function. */
Expand All @@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
void VisitAttrs(tvm::AttrVisitor* v) {
// `f_tune` is not visited
// `f_initialize_task` is not visited
// `f_set_task_stopped` is not visited
// `f_is_task_running` is not visited
// `f_touch_task` is not visited
// `f_join_running_task` is not visited
// `f_next_task_id` is not visited
}
Expand All @@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
}
}

void SetTaskStopped(int task_id) final {
if (f_set_task_stopped == nullptr) {
TaskSchedulerNode::SetTaskStopped(task_id);
} else {
f_set_task_stopped(task_id);
}
}

bool IsTaskRunning(int task_id) final {
if (f_is_task_running == nullptr) {
return TaskSchedulerNode::IsTaskRunning(task_id);
void TouchTask(int task_id) final {
if (f_touch_task == nullptr) {
return TaskSchedulerNode::TouchTask(task_id);
} else {
return f_is_task_running(task_id);
return f_touch_task(task_id);
}
}

void JoinRunningTask(int task_id) final {
Array<RunnerResult> JoinRunningTask(int task_id) final {
if (f_join_running_task == nullptr) {
return TaskSchedulerNode::JoinRunningTask(task_id);
} else {
Expand Down Expand Up @@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \return The task scheduler created.
Expand All @@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef {
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
* \param task_weights The weights of each task.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param alpha The parameter alpha to control gradient computation.
* \param window_size The parameter to control backward window size.
* \param seed The random seed.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks,
Array<FloatImm> task_weights, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
double alpha, //
int window_size, //
support::LinearCongruentialEngine::TRandState seed);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_set_task_stopped The packed function of `SetTaskStopped`.
* \param f_is_task_running The packed function of `IsTaskRunning`.
* \param f_touch_task The packed function of `TouchTask`.
* \param f_join_running_task The packed function of `JoinRunningTask`.
* \param f_next_task_id The packed function of `NextTaskId`.
* \return The task scheduler created.
Expand All @@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef {
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FTouchTask f_touch_task, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object {
/*! \brief The task scheduler that owns the tune context */
const TaskSchedulerNode* task_scheduler;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
bool is_terminated;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
Expand All @@ -81,7 +81,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
v->Visit("is_terminated", &is_terminated);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ class LinearCongruentialEngine {
* \brief Change the start random state of RNG with the seed of a new random state value.
* \param rand_state The random state given in result_type.
*/
void Seed(TRandState rand_state = 1) {
ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!";
rand_state %= modulus; // Make sure the seed is within the range of modulus.
if (rand_state == 0)
rand_state = 1; // Avoid getting all 0 given the current parameter set.
else if (rand_state < 0)
rand_state += modulus; // Make sure the rand state is non-negative.
ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null.
*rand_state_ptr_ = rand_state; // Change pointed random state to given random state value.
void Seed(TRandState rand_state) {
// A seed of -1 requests a non-deterministic seed from the device random source.
if (rand_state == -1) {
rand_state = DeviceRandom();
} else if (rand_state == 0) {
// A zero state would keep the generator at all zeros for the current
// parameter set (see the note in the previous revision), so remap it to 1.
rand_state = 1;
}
ICHECK(rand_state >= 0) << "The random state should be nonnegative";
ICHECK(rand_state_ptr_ != nullptr);
// Store the state reduced modulo the LCG modulus so it is always in range.
*rand_state_ptr_ = rand_state % modulus;
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ScheduleNode : public runtime::Object {
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
/*! \brief Fork the random state */
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;

Expand Down
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class EvolutionarySearch(SearchStrategy):
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
max_trials_per_task : int
The maximum number of trials per task.
population_size : int
The initial population of traces from measured samples and randomly generated samples.
Expand All @@ -53,7 +53,7 @@ class EvolutionarySearch(SearchStrategy):
"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
population_size: int
init_measured_ratio: int
init_min_unmeasured: int
Expand All @@ -66,7 +66,7 @@ def __init__(
self,
*,
num_trials_per_iter: int,
num_trials_total: int,
max_trials_per_task: int,
population_size: int,
init_measured_ratio: float,
init_min_unmeasured: int,
Expand All @@ -79,7 +79,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
max_trials_per_task,
population_size,
init_measured_ratio,
init_min_unmeasured,
Expand All @@ -94,7 +94,8 @@ class EvolutionarySearchConfig(NamedTuple):
"""Configuration for EvolutionarySearch"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
max_trials_global: int
population_size: int = 2048
init_measured_ratio: float = 0.2
init_min_unmeasured: int = 50
Expand All @@ -106,7 +107,7 @@ class EvolutionarySearchConfig(NamedTuple):
def create_strategy(self) -> EvolutionarySearch:
return EvolutionarySearch(
num_trials_per_iter=self.num_trials_per_iter,
num_trials_total=self.num_trials_total,
max_trials_per_task=self.max_trials_per_task,
population_size=self.population_size,
init_measured_ratio=self.init_measured_ratio,
init_min_unmeasured=self.init_min_unmeasured,
Expand Down
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/search_strategy/replay_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,32 @@ class ReplayFunc(SearchStrategy):
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
max_trials_per_task : int
The maximum number of trials per task.
"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int

def __init__(
self,
num_trials_per_iter: int,
num_trials_total: int,
max_trials_per_task: int,
):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
max_trials_per_task,
)


class ReplayFuncConfig(NamedTuple):
"""Configuration for ReplayFunc"""

num_trials_per_iter: int
num_trials_total: int
max_trials_per_task: int
max_trials_global: int

def create_strategy(self) -> ReplayFunc:
return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task)
Loading