diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 298cdae4283a..f385b72db46d 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -19,5 +19,7 @@ Meta Schedule search strategy utilizes the design spaces given to generate measure candidates. """ -from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy -from .replay_trace import ReplayTrace + +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate +from .replay_trace import ReplayTrace, ReplayTraceConfig +from .replay_func import ReplayFunc, ReplayFuncConfig diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 000000000000..eacc2776fcbb --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transform the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) + + +class ReplayFuncConfig(NamedTuple): + """Configuration for ReplayFunc""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayFunc: + return ReplayFunc(self.num_trials_per_iter, self.num_trials_total) diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc new file mode 100644 index 000000000000..27383adf84e0 --- /dev/null +++ b/src/meta_schedule/mutator/mutator.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Mutator Mutator::PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Mutator(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMutatorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MutatorNode); +TVM_REGISTER_NODE_TYPE(PyMutatorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") + .set_body_method(&MutatorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") + .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); + return self->Apply(trace, &seed_); + }); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc new file mode 100644 index 000000000000..ff069e2c68cb --- /dev/null +++ b/src/meta_schedule/postproc/postproc.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Postproc Postproc::PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Postproc(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyPostprocNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(PostprocNode); +TVM_REGISTER_NODE_TYPE(PyPostprocNode); + +TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") + .set_body_method(&PostprocNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 000000000000..7592a8a2418e --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; + /*! \brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& context) final { + this->space_generator_ = context->space_generator.value(); + this->mod_ = context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); + this->postprocs_ = context->postprocs; + this->rand_state_ = ForkSeed(&context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + for (;;) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); + tir::Schedule sch = schs[design_space_index]; + sch->EnterPostproc(); + bool failed = false; + for (const Postproc& proc : self->postprocs_) { + if (!proc->Apply(sch)) { + failed = true; + break; + } + } + if (!failed) { + result.push_back(MeasureCandidate(sch, self->args_info_)); + break; + } + } + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 668fca9ecbbf..a4d32175eb0b 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -18,12 +18,11 @@ # pylint: disable=missing-function-docstring import sys import pytest -from typing import List - import tvm from tvm.meta_schedule import TuneContext from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import ( + ReplayFunc, ReplayTrace, SearchStrategy, ) @@ -75,17 +74,17 @@ def _schedule_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -@pytest.mark.parametrize("TestClass", [ReplayTrace]) +@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 num_trials_total = 20 strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - context.space_generator.initialize_with_tune_context(context) - spaces = context.space_generator.generate_design_space(context.mod) + tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) - strategy.initialize_with_tune_context(context) + strategy.initialize_with_tune_context(tune_context) strategy.pre_tuning(spaces) (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] @@ -100,7 +99,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(context, candidates, runner_results) + strategy.notify_runner_results(tune_context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6]