diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 19358552df..b809843f41 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode { } Array Build(const Array& build_inputs) final { + ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 7ba3c207e3..60c6898f00 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode { // `f_size` is not visited } - static constexpr const char* _type_key = "meta_schedule.PyDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); - - Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); } + Workload CommitWorkload(const IRModule& mod) final { + ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; + return f_commit_workload(mod); + } - void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); } + void CommitTuningRecord(const TuningRecord& record) final { + ICHECK(f_commit_tuning_record != nullptr) + << "PyDatabase's CommitTuningRecord method not implemented!"; + f_commit_tuning_record(record); + } Array GetTopK(const Workload& workload, int top_k) final { + ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } - int64_t Size() final { return f_size(); } + int64_t Size() final { + ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; + return f_size(); + } + + static constexpr const char* _type_key = "meta_schedule.PyDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); }; /*! diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index e1a4a02265..9ee7039959 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -96,6 +96,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { const Array& measure_candidates, // const Array& builds, // const Array& results) final { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results); } diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index ccbfff0ff9..82f5b76834 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -86,10 +86,15 @@ class PyMutatorNode : public MutatorNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } - Optional Apply(const tir::Trace& trace) final { return this->f_apply(trace); } + Optional Apply(const tir::Trace& trace) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace); + } static constexpr const char* _type_key = "meta_schedule.PyMutator"; TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 3ac6f1c4fe..c24861d697 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -91,10 +91,15 @@ class PyPostprocNode : public PostprocNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } - bool Apply(const tir::Schedule& sch) final { return this->f_apply(sch); } + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } static constexpr const char* _type_key = "meta_schedule.PyPostproc"; TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c1451ae977..b154195f43 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode { // `f_run` is not visited } - Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + Array Run(Array runner_inputs) final { + ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; + return f_run(runner_inputs); + } static constexpr const char* _type_key = "meta_schedule.PyRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index af4ab91ba9..92aa46beea 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -90,10 +90,13 @@ class PyScheduleRuleNode : public ScheduleRuleNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return this->f_apply(sch, block); } diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index cbac016c3c..3a0fa0ab4a 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -188,20 +188,30 @@ class PySearchStrategyNode : public SearchStrategyNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } void PreTuning(const Array& design_spaces) final { + ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; this->f_pre_tuning(design_spaces); } - void PostTuning() final { this->f_post_tuning(); } + void PostTuning() final { + ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; + this->f_post_tuning(); + } Optional> GenerateMeasureCandidates() final { + ICHECK(f_generate_measure_candidates != nullptr) + << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return this->f_generate_measure_candidates(); } void NotifyRunnerResults(const Array& results) final { + ICHECK(f_notify_runner_results != nullptr) + << "PySearchStrategy's NotifyRunnerResults method not implemented!"; this->f_notify_runner_results(results); } diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 27537a9211..a0dfede820 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { } void InitializeWithTuneContext(const TuneContext& tune_context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySpaceGenerator's InitializeWithTuneContext !"; f_initialize_with_tune_context(tune_context); } Array GenerateDesignSpace(const IRModule& mod) final { + ICHECK(f_generate_design_space != nullptr) + << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); } diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 9c592b2794..062f493ec3 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -126,6 +126,8 @@ class TaskSchedulerNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); }; +class TaskScheduler; + /*! \brief The task scheduler with customized methods on the python-side. */ class PyTaskSchedulerNode : public TaskSchedulerNode { public: @@ -183,26 +185,47 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { } void Tune() final { // - f_tune(); + if (f_tune == nullptr) { + TaskSchedulerNode::Tune(); + } else { + f_tune(); + } } void InitializeTask(int task_id) final { // - f_initialize_task(task_id); + if (f_initialize_task == nullptr) { + TaskSchedulerNode::InitializeTask(task_id); + } else { + f_initialize_task(task_id); + } } void SetTaskStopped(int task_id) final { // - f_set_task_stopped(task_id); + if (f_set_task_stopped == nullptr) { + TaskSchedulerNode::SetTaskStopped(task_id); + } else { + f_set_task_stopped(task_id); + } } bool IsTaskRunning(int task_id) final { // - return f_is_task_running(task_id); + if (f_is_task_running == nullptr) { + return TaskSchedulerNode::IsTaskRunning(task_id); + } else { + return f_is_task_running(task_id); + } } void JoinRunningTask(int task_id) final { // - f_join_running_task(task_id); + if (f_join_running_task == nullptr) { + return TaskSchedulerNode::JoinRunningTask(task_id); + } else { + return f_join_running_task(task_id); + } } int NextTaskId() final { // + ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; return f_next_task_id(); } diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index 733e3ce16e..381051e85f 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -23,6 +23,7 @@ from tvm.target import Target from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.BuilderInput") @@ -119,6 +120,7 @@ class PyBuilder(Builder): def __init__(self): """Constructor.""" + @check_override(self.__class__, Builder) def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: return self.build(build_inputs) diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index f676c2ff50..fd746e640c 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -25,7 +25,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo -from ..utils import _json_de_tvm +from ..utils import _json_de_tvm, check_override @register_object("meta_schedule.Workload") @@ -207,15 +207,19 @@ class PyDatabase(Database): def __init__(self): """Constructor.""" + @check_override(self.__class__, Database) def f_commit_workload(mod: IRModule) -> Workload: return self.commit_workload(mod) + @check_override(self.__class__, Database) def f_commit_tuning_record(record: TuningRecord) -> None: self.commit_tuning_record(record) + @check_override(self.__class__, Database) def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]: return self.get_top_k(workload, top_k) + @check_override(self.__class__, Database, func_name="__len__") def f_size() -> int: return len(self) @@ -225,4 +229,4 @@ def f_size() -> int: f_commit_tuning_record, f_get_top_k, f_size, - ) \ No newline at end of file + ) diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index 8d487905ea..f7daed55f6 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -20,16 +20,16 @@ from tvm._ffi import register_object from tvm.runtime import Object -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.search_strategy import MeasureCandidate -from tvm.meta_schedule.builder import BuilderResult -from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.utils import _get_hex_address + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..builder import BuilderResult +from ..runner import RunnerResult +from ..utils import _get_hex_address, check_override from .. import _ffi_api if TYPE_CHECKING: - from ..tune_context import TuneContext from ..task_scheduler import TaskScheduler @@ -77,6 +77,7 @@ class PyMeasureCallback(MeasureCallback): def __init__(self): """Constructor.""" + @check_override(self.__class__, MeasureCallback) def f_apply( task_scheduler: "TaskScheduler", tasks: List[TuneContext], diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index e33431e15c..f583154fec 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -21,7 +21,7 @@ from tvm.runtime import Object from tvm.tir.schedule import Trace -from ..utils import _get_hex_address +from ..utils import _get_hex_address, check_override from .. import _ffi_api if TYPE_CHECKING: @@ -66,9 +66,11 @@ class PyMutator(Mutator): def __init__(self): """Constructor.""" + @check_override(self.__class__, Mutator) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, Mutator) def f_apply(trace: Trace) -> Optional[Trace]: return self.apply(trace) diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 930a04f471..06e0da8fd3 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -18,12 +18,12 @@ from typing import TYPE_CHECKING -from tvm._ffi import register_object, register_func +from tvm._ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule -from tvm.meta_schedule.utils import _get_hex_address from .. import _ffi_api +from ..utils import _get_hex_address, check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -75,9 +75,11 @@ class PyPostproc(Postproc): def __init__(self): """Constructor.""" + @check_override(self.__class__, Postproc) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, Postproc) def f_apply(sch: Schedule) -> bool: return self.apply(sch) diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 6f92554055..71a557dca3 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -22,6 +22,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo +from ..utils import check_override @register_object("meta_schedule.RunnerInput") @@ -158,6 +159,7 @@ class PyRunner(Runner): def __init__(self) -> None: """Constructor""" + @check_override(self.__class__, Runner) def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: return self.run(runner_inputs) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 1e10fe4183..ec101410f6 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -24,7 +24,7 @@ from tvm.runtime import Object from tvm.tir.schedule import Schedule, BlockRV -from ..utils import _get_hex_address +from ..utils import _get_hex_address, check_override from .. import _ffi_api if TYPE_CHECKING: @@ -72,9 +72,11 @@ class PyScheduleRule(ScheduleRule): def __init__(self): """Constructor.""" + @check_override(self.__class__, ScheduleRule) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, ScheduleRule) def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: return self.apply(sch, block) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 6e4b4a43cd..02cc80b844 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -27,6 +27,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo from ..runner import RunnerResult +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -129,18 +130,23 @@ class PySearchStrategy(SearchStrategy): def __init__(self): """Constructor.""" + @check_override(self.__class__, SearchStrategy) def f_initialize_with_tune_context(context: "TuneContext") -> None: self.initialize_with_tune_context(context) + @check_override(self.__class__, SearchStrategy) def f_pre_tuning(design_spaces: List[Schedule]) -> None: self.pre_tuning(design_spaces) + @check_override(self.__class__, SearchStrategy) def f_post_tuning() -> None: self.post_tuning() + @check_override(self.__class__, SearchStrategy) def f_generate_measure_candidates() -> List[MeasureCandidate]: return self.generate_measure_candidates() + @check_override(self.__class__, SearchStrategy) def f_notify_runner_results(results: List["RunnerResult"]) -> None: self.notify_runner_results(results) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 0ce654ff58..2172613ce1 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -26,6 +26,7 @@ from tvm.tir.schedule import Schedule from .. import _ffi_api +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -70,9 +71,11 @@ class PySpaceGenerator(SpaceGenerator): def __init__(self): """Constructor.""" + @check_override(self.__class__, SpaceGenerator) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, SpaceGenerator) def f_generate_design_space(mod: IRModule) -> List[Schedule]: return self.generate_design_space(mod) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 1e4042549e..2af3852ab7 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -27,6 +27,7 @@ from ..database import Database from ..tune_context import TuneContext from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.TaskScheduler") @@ -57,6 +58,16 @@ def tune(self) -> None: """Auto-tuning.""" _ffi_api.TaskSchedulerTune(self) # pylint: disable=no-member + def next_task_id(self) -> int: + """Fetch the next task id. + + Returns + ------- + int + The next task id. + """ + return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member + def _initialize_task(self, task_id: int) -> None: """Initialize modules of the given task. @@ -102,16 +113,6 @@ def _join_running_task(self, task_id: int) -> None: """ _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member - def _next_task_id(self) -> int: - """Fetch the next task id. - - Returns - ------- - int - The next task id. - """ - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member - @register_object("meta_schedule.PyTaskScheduler") class PyTaskScheduler(TaskScheduler): @@ -141,24 +142,38 @@ def __init__( The list of measure callbacks of the scheduler. """ + @check_override(self.__class__, TaskScheduler, required=False) def f_tune() -> None: self.tune() + @check_override(self.__class__, TaskScheduler) + def f_next_task_id() -> int: + return self.next_task_id() + + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_initialize_task" + ) def f_initialize_task(task_id: int) -> None: self._initialize_task(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_set_task_stopped" + ) def f_set_task_stopped(task_id: int) -> None: self._set_task_stopped(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_is_task_running" + ) def f_is_task_running(task_id: int) -> bool: return self._is_task_running(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_join_running_task" + ) def f_join_running_task(task_id: int) -> None: self._join_running_task(task_id) - def f_next_task_id() -> int: - return self._next_task_id() - self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerPyTaskScheduler, # pylint: disable=no-member tasks, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 8e1cdfbd71..6989ada79a 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -23,6 +23,7 @@ import psutil import tvm +from tvm import meta_schedule from tvm._ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, Map, IRModule @@ -61,7 +62,8 @@ def cpu_count(logical: bool = True) -> int: def get_global_func_with_default_on_worker( - name: Union[None, str, Callable], default: Callable, + name: Union[None, str, Callable], + default: Callable, ) -> Callable: """Get the registered global function on the worker process. @@ -97,7 +99,9 @@ def get_global_func_with_default_on_worker( def get_global_func_on_rpc_session( - session: RPCSession, name: str, extra_error_msg: Optional[str] = None, + session: RPCSession, + name: str, + extra_error_msg: Optional[str] = None, ) -> PackedFunc: """Get a PackedFunc from the global registry from an RPCSession. @@ -212,10 +216,51 @@ def _get_hex_address(handle: ctypes.c_void_p) -> str: ---------- handle : ctypes.c_void_p The handle to be converted. - + Returns ------- result : str The hexadecimal address of the handle. """ return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + +def check_override( + derived_class: Any, base_class: Any, required: bool = True, func_name: str = None +) -> Callable: + """Check if the derived class has overrided the base class's method. + + Parameters + ---------- + derived_class : Any + The derived class. + base_class : Any + The base class of derived class. + required : bool + If the method override is required. + func_name : str + Name of the method. Default value None, which would be set to substring of the given + function, e.g. `f_generate`->`generate`. + + Returns + ------- + func : Callable + Raise NotImplementedError if the function is required and not overrided. If the + function is not overrided return None, other return the overrided function. + """ + + def inner(func: Callable): + + if func_name is None: + method = func.__name__[2:] + else: + method = func_name + + if getattr(derived_class, method) is getattr(base_class, method): + if required: + raise NotImplementedError(f"{derived_class}'s {method} method is not implemented!") + else: + return None + return func + + return inner diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 8caee9447d..733d118c73 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -34,7 +34,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) const auto* self = n.as(); ICHECK(self); PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr); + ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 9256b95227..9bf6161b55 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -37,7 +37,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) const auto* self = n.as(); ICHECK(self); PyMutatorNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr); + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index d782a67c2d..ff069e2c68 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -37,7 +37,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) const auto* self = n.as(); ICHECK(self); PyPostprocNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr); + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 94de0417cb..f80f684daf 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -37,7 +37,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) const auto* self = n.as(); ICHECK(self); PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr); + ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index 9b7df907b0..11886acde6 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -109,6 +109,16 @@ def apply( def test_meta_schedule_measure_callback_as_string(): class NotSoFancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + pass + def __str__(self) -> str: return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py index f1656239d1..aedc6d2658 100644 --- a/tests/python/unittest/test_meta_schedule_mutator.py +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -51,6 +51,9 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: def test_meta_schedule_mutator(): class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + def apply(self, trace: Trace) -> Optional[Trace]: return Trace(trace.insts, {}) @@ -65,6 +68,12 @@ def apply(self, trace: Trace) -> Optional[Trace]: def test_meta_schedule_mutator_as_string(): class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + def __str__(self) -> str: return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" @@ -76,4 +85,3 @@ def __str__(self) -> str: if __name__ == "__main__": test_meta_schedule_mutator() test_meta_schedule_mutator_as_string() - diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py index fc7926dc4a..34f9ffb885 100644 --- a/tests/python/unittest/test_meta_schedule_postproc.py +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -64,6 +64,9 @@ def schedule_matmul(sch: Schedule): def test_meta_schedule_postproc(): class FancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + def apply(self, sch: Schedule) -> bool: schedule_matmul(sch) return True @@ -81,6 +84,9 @@ def apply(self, sch: Schedule) -> bool: def test_meta_schedule_postproc_fail(): class FailingPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + def apply(self, sch: Schedule) -> bool: return False @@ -91,6 +97,12 @@ def apply(self, sch: Schedule) -> bool: def test_meta_schedule_postproc_as_string(): class NotSoFancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + pass + def __str__(self) -> str: return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py index a8195080f4..a7480e22b3 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -59,6 +59,9 @@ def _check_correct(schedule: Schedule): def test_meta_schedule_schedule_rule(): class FancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: i, j, k = sch.get_loops(block=block) i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) @@ -81,6 +84,12 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: def test_meta_schedule_schedule_rule_as_string(): class YetStillSomeFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + pass + def __str__(self) -> str: return f"YetStillSomeFancyScheduleRule({_get_hex_address(self.handle)})" diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 39bb1acf06..77dd1ec2b2 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,6 +23,9 @@ import pytest import tvm +from tvm._ffi.base import TVMError +from tvm.ir.module import IRModule +from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule @@ -84,5 +87,10 @@ def test_meta_schedule_design_space_generator_union(): _check_correct(design_space) +def test_meta_schedule_design_space_generator_NIE(): + with pytest.raises(NotImplementedError): + PySpaceGenerator() + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index a304096965..65d8b8e161 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -24,17 +24,18 @@ import pytest import tvm +from tvm._ffi.base import TVMError from tvm.script import tir as T from tvm.ir import IRModule -from tvm.tir import Schedule +from tvm.tir import Schedule, schedule from tvm.meta_schedule import TuneContext from tvm.meta_schedule.space_generator import ScheduleFn from tvm.meta_schedule.search_strategy import ReplayTrace from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.task_scheduler import RoundRobin -from tvm.meta_schedule.utils import structural_hash +from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler +from tvm.tir.expr import Not # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -216,5 +217,77 @@ def test_meta_schedule_task_scheduler_multiple(): assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total +def test_meta_schedule_task_scheduler_NIE(): + class MyTaskScheduler(PyTaskScheduler): + pass + + with pytest.raises(NotImplementedError): + MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) + + +def test_meta_schedule_task_scheduler_override_next_task_id_only(): + class MyTaskScheduler(PyTaskScheduler): + done = set() + + def next_task_id(self) -> int: + while len(self.done) != len(tasks): + x = random.randint(0, len(tasks) - 1) + task = tasks[x] + if not task.is_stopped: + """Calling base func via following route: + Python side: + PyTaskScheduler does not have `_is_task_running` + Call TaskScheduler's `is_task_running`, which calls ffi + C++ side: + The ffi calls TaskScheduler's `is_task_running` + But it is overrided in PyTaskScheduler + PyTaskScheduler checks if the function is overrided in python + If not, it returns the TaskScheduler's vtable, calling + TaskScheduler::IsTaskRunning + """ + if self._is_task_running(x): + # Same Here + self._join_running_task(x) + return x + else: + self.done.add(x) + return -1 + + num_trials_per_iter = 6 + num_trials_total = 101 + tasks = [ + TuneContext( + MatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="Matmul", + rand_state=42, + ), + TuneContext( + MatmulReluModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="MatmulRelu", + rand_state=0xDEADBEEF, + ), + TuneContext( + BatchMatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="BatchMatmul", + rand_state=0x114514, + ), + ] + database = DummyDatabase() + scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database) + scheduler.tune() + assert len(database) == num_trials_total * len(tasks) + for task in tasks: + assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))