diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 941dae4336..cbac016c3c 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { @@ -237,6 +238,13 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + /*! + * \brief Constructor of replay func search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for func replaying. + */ + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index ed81f4c0d3..733e3ce16e 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -126,6 +126,3 @@ def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: _ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member f_build, ) - - def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 3d05441fe2..f676c2ff50 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -225,16 +225,4 @@ def f_size() -> int: f_commit_tuning_record, f_get_top_k, f_size, - ) - - def commit_workload(self, mod: IRModule) -> Workload: - raise NotImplementedError - - def commit_tuning_record(self, record: TuningRecord) -> None: - raise NotImplementedError - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError + ) \ No newline at end of file diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index af998cfa67..4e675e4ea7 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -82,11 +82,5 @@ def f_as_string() -> str: f_as_string, ) - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def apply(self, trace: Trace) -> Optional[Trace]: - raise NotImplementedError - def __str__(self) -> str: return f"PyMutator({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 5eb57d4384..c998ffc028 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -91,11 +91,5 @@ def f_as_string() -> str: f_as_string, ) - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def apply(self, sch: Schedule) -> bool: - raise NotImplementedError - def __str__(self) -> str: return f"PyPostproc({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 9f7be8ea4a..6f92554055 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -165,6 +165,3 @@ def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member f_run, ) - - def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 1d60eb7b3c..1d90622c28 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -88,11 +88,5 @@ def f_as_string() -> str: f_as_string, ) - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: - raise NotImplementedError - def __str__(self) -> str: return f"PyScheduleRule({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 609baa2677..f6fd50c15a 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -22,3 +22,4 @@ from .search_strategy import SearchStrategy, PySearchStrategy from .replay_trace import ReplayTrace +from .replay_func import ReplayFunc diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 0000000000..8edd74ab02 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transform the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 1579f56a90..6e4b4a43cd 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -152,18 +152,3 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None: f_generate_measure_candidates, f_notify_runner_results, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def pre_tuning(self, design_spaces: List[Schedule]) -> None: - raise NotImplementedError - - def post_tuning(self) -> None: - raise NotImplementedError - - def generate_measure_candidates(self) -> List[MeasureCandidate]: - raise NotImplementedError - - def notify_runner_results(self, results: List["RunnerResult"]) -> None: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index af7462ea21..0ce654ff58 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -81,9 +81,3 @@ def f_generate_design_space(mod: IRModule) -> List[Schedule]: f_initialize_with_tune_context, f_generate_design_space, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def generate_design_space(self, mod) -> List[Schedule]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index e6ef2bb15f..531203f375 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -119,21 +119,3 @@ def f_next_task_id() -> int: f_join_running_task, f_next_task_id, ) - - def tune(self) -> None: - raise NotImplementedError() - - def _initialize_task(self, task_id: int) -> None: - raise _ffi_api.TaskSchedulerInitializeTask(self, task_id) - - def _set_task_stopped(self, task_id: int) -> None: - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member - - def _is_task_running(self, task_id: int) -> bool: - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member - - def _join_running_task(self, task_id: int) -> None: - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member - - def _next_task_id(self) -> int: - return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index d075bf2a35..af21908639 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -77,8 +77,8 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] sch_rules: List["ScheduleRule"] postproc: List["Postproc"] mutator: List["Mutator"] diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 0000000000..5c00b3dc04 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->space_generator_ = tune_context->space_generator.value(); + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + result.push_back(MeasureCandidate(schs[tir::SampleInt(&self->rand_state_, 0, schs.size())], + self->args_info_)); + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 18ffd8e44f..c4ee3c4679 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -82,7 +82,7 @@ class ReplayTraceNode : public SearchStrategyNode { this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); } - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_[0])); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index e128713915..03b27b914e 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -26,7 +26,7 @@ from tvm.meta_schedule import TuneContext from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace, ReplayFunc from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -54,9 +54,13 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: - trace_1 = Trace(sch_1.trace.insts, {}) - trace_2 = Trace(sch_2.trace.insts, {}) +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool: + if remove_decisions: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + else: + trace_1 = sch_1.trace + trace_2 = sch_2.trace return str(trace_1) == str(trace_2) @@ -70,29 +74,35 @@ def _schedule_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -def test_meta_schedule_replay_trace(): +@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) +def test_meta_schedule_replay_func(TestClass: SearchStrategy): num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul) - replay.initialize_with_tune_context(tune_context) - - num_trials_each_round: List[int] = [] - replay.pre_tuning([example_sch]) - while True: - candidates = replay.generate_measure_candidates() - if candidates is None: - break - num_trials_each_round.append(len(candidates)) + strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) runner_results: List[RunnerResult] = [] for candidate in candidates: - assert _is_trace_equal(candidate.sch, example_sch) - runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) - replay.notify_runner_results(runner_results) - replay.post_tuning() - assert num_trials_each_round == [7, 7, 6] + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(type(strategy) == ReplayTrace), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + assert num_trials_each_iter == [7, 7, 6] if __name__ == "__main__":