Skip to content

Commit

Permalink
Minor fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Dec 22, 2021
1 parent 62d85dc commit cc428ea
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ class SpaceGenerator : public ObjectRef {
public:
/*!
* \brief Create a design space generator with customized methods on the python-side.
* \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`.
* \param generate_design_space_func The packed function of `GenerateDesignSpace`.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_generate_design_space The packed function of `GenerateDesignSpace`.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PySpaceGenerator(
PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func,
PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func);
PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context,
PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space);

/*!
* \brief Create a design space generator that is union of multiple design space generators.
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,32 @@ class TaskScheduler : public runtime::ObjectRef {
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_set_task_stopped The packed function of `SetTaskStopped`.
* \param f_is_task_running The packed function of `IsTaskRunning`.
* \param f_join_running_task The packed function of `JoinRunningTask`.
* \param f_next_task_id The packed function of `NextTaskId`.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Expand Down
10 changes: 5 additions & 5 deletions tests/python/unittest/test_meta_schedule_search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl
num_trials_total = 20

strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
tune_context.space_generator.initialize_with_tune_context(tune_context)
spaces = tune_context.space_generator.generate_design_space(tune_context.mod)
context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
context.space_generator.initialize_with_tune_context(context)
spaces = context.space_generator.generate_design_space(context.mod)

strategy.initialize_with_tune_context(tune_context)
strategy.initialize_with_tune_context(context)
strategy.pre_tuning(spaces)
(correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
num_trials_each_iter: List[int] = []
Expand All @@ -100,7 +100,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl
remove_decisions=(isinstance(strategy, ReplayTrace)),
)
runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
strategy.notify_runner_results(tune_context, candidates, runner_results)
strategy.notify_runner_results(context, candidates, runner_results)
candidates = strategy.generate_measure_candidates()
strategy.post_tuning()
assert num_trials_each_iter == [7, 7, 6]
Expand Down

0 comments on commit cc428ea

Please sign in to comment.