diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 7aff6839dc555..3611870c7c9b2 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -139,13 +139,13 @@ class SpaceGenerator : public ObjectRef { public: /*! * \brief Create a design space generator with customized methods on the python-side. - * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. - * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_generate_design_space The packed function of `GenerateDesignSpace`. * \return The design space generator created. */ TVM_DLL static SpaceGenerator PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func, - PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func); + PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, + PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space); /*! * \brief Create a design space generator that is union of multiple design space generators. diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 0284a55e0d03f..ddd6f4c4815f6 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -249,6 +249,9 @@ class TaskScheduler : public runtime::ObjectRef { * \param builder The builder of the scheduler. * \param runner The runner of the scheduler. * \param database The database of the scheduler. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \return The task scheduler created. */ TVM_DLL static TaskScheduler RoundRobin(Array tasks, // Builder builder, // @@ -256,6 +259,22 @@ class TaskScheduler : public runtime::ObjectRef { Database database, // Optional cost_model, // Optional> measure_callbacks); + /*! + * \brief Create a task scheduler with customized methods on the python-side. + * \param tasks The tasks to be tuned. + * \param builder The builder of the scheduler. + * \param runner The runner of the scheduler. + * \param database The database of the scheduler. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \param f_tune The packed function of `Tune`. + * \param f_initialize_task The packed function of `InitializeTask`. + * \param f_set_task_stopped The packed function of `SetTaskStopped`. + * \param f_is_task_running The packed function of `IsTaskRunning`. + * \param f_join_running_task The packed function of `JoinRunningTask`. + * \param f_next_task_id The packed function of `NextTaskId`. + * \return The task scheduler created. + */ TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 6ef3771fb7833..668fca9ecbbf2 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -81,11 +81,11 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl num_trials_total = 20 strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - tune_context.space_generator.initialize_with_tune_context(tune_context) - spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) - strategy.initialize_with_tune_context(tune_context) + strategy.initialize_with_tune_context(context) strategy.pre_tuning(spaces) (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] @@ -100,7 +100,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(tune_context, candidates, runner_results) + strategy.notify_runner_results(context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6]