Skip to content

Commit

Permalink
[M3c][Meta Schedule] Implement the Replay Func class. (#495)
Browse files Browse the repository at this point in the history
* Add replay func.

* Modify pretuning.

* Modify docs.

* Minor nit.

* Enhance unittest.

* Nit.

* Fix replay func.

* Simplify unittest.

* Fix PyClass reflection.
  • Loading branch information
zxybazh authored Oct 28, 2021
1 parent 17a2180 commit ca3f72c
Show file tree
Hide file tree
Showing 16 changed files with 230 additions and 101 deletions.
8 changes: 8 additions & 0 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
Expand Down Expand Up @@ -237,6 +238,13 @@ class SearchStrategy : public runtime::ObjectRef {
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);

/*!
* \brief Constructor of replay func search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param num_trials_total The total number of trials for func replaying.
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
};

Expand Down
3 changes: 0 additions & 3 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,3 @@ def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
_ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member
f_build,
)

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
raise NotImplementedError
14 changes: 1 addition & 13 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,4 @@ def f_size() -> int:
f_commit_tuning_record,
f_get_top_k,
f_size,
)

def commit_workload(self, mod: IRModule) -> Workload:
raise NotImplementedError

def commit_tuning_record(self, record: TuningRecord) -> None:
raise NotImplementedError

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
raise NotImplementedError

def __len__(self) -> int:
raise NotImplementedError
)
6 changes: 0 additions & 6 deletions python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,5 @@ def f_as_string() -> str:
f_as_string,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def apply(self, trace: Trace) -> Optional[Trace]:
raise NotImplementedError

def __str__(self) -> str:
return f"PyMutator({_get_hex_address(self.handle)})"
6 changes: 0 additions & 6 deletions python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,5 @@ def f_as_string() -> str:
f_as_string,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def apply(self, sch: Schedule) -> bool:
raise NotImplementedError

def __str__(self) -> str:
return f"PyPostproc({_get_hex_address(self.handle)})"
3 changes: 0 additions & 3 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,3 @@ def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
_ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member
f_run,
)

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
raise NotImplementedError
6 changes: 0 additions & 6 deletions python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,5 @@ def f_as_string() -> str:
f_as_string,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
raise NotImplementedError

def __str__(self) -> str:
return f"PyScheduleRule({_get_hex_address(self.handle)})"
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@

from .search_strategy import SearchStrategy, PySearchStrategy
from .replay_trace import ReplayTrace
from .replay_func import ReplayFunc
51 changes: 51 additions & 0 deletions python/tvm/meta_schedule/search_strategy/replay_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Replay Trace Search Strategy"""

from tvm._ffi import register_object
from .search_strategy import SearchStrategy
from .. import _ffi_api


@register_object("meta_schedule.ReplayFunc")
class ReplayFunc(SearchStrategy):
    """Replay Func Search Strategy is a search strategy that generates measure
    candidates by calling the design space generator repeatedly and sampling
    schedules from the generated design spaces.

    Parameters
    ----------
    num_trials_per_iter : int
        Number of trials per iteration, i.e., the batch size.
    num_trials_total : int
        Total number of trials.
    """

    num_trials_per_iter: int
    num_trials_total: int

    def __init__(
        self,
        num_trials_per_iter: int,
        num_trials_total: int,
    ):
        """Constructor.

        Parameters
        ----------
        num_trials_per_iter : int
            Number of trials per iteration, i.e., the batch size.
        num_trials_total : int
            Total number of trials.
        """
        self.__init_handle_by_constructor__(
            _ffi_api.SearchStrategyReplayFunc,  # type: ignore # pylint: disable=no-member
            num_trials_per_iter,
            num_trials_total,
        )
15 changes: 0 additions & 15 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,3 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None:
f_generate_measure_candidates,
f_notify_runner_results,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def pre_tuning(self, design_spaces: List[Schedule]) -> None:
raise NotImplementedError

def post_tuning(self) -> None:
raise NotImplementedError

def generate_measure_candidates(self) -> List[MeasureCandidate]:
raise NotImplementedError

def notify_runner_results(self, results: List["RunnerResult"]) -> None:
raise NotImplementedError
6 changes: 0 additions & 6 deletions python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,3 @@ def f_generate_design_space(mod: IRModule) -> List[Schedule]:
f_initialize_with_tune_context,
f_generate_design_space,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def generate_design_space(self, mod) -> List[Schedule]:
raise NotImplementedError
18 changes: 0 additions & 18 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,3 @@ def f_next_task_id() -> int:
f_join_running_task,
f_next_task_id,
)

def tune(self) -> None:
raise NotImplementedError()

def _initialize_task(self, task_id: int) -> None:
raise _ffi_api.TaskSchedulerInitializeTask(self, task_id)

def _set_task_stopped(self, task_id: int) -> None:
_ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # pylint: disable=no-member

def _is_task_running(self, task_id: int) -> bool:
return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # pylint: disable=no-member

def _join_running_task(self, task_id: int) -> None:
_ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # pylint: disable=no-member

def _next_task_id(self) -> int:
return _ffi_api.TaskSchedulerNextTaskId(self) # pylint: disable=no-member
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class TuneContext(Object):

mod: Optional[IRModule]
target: Optional[Target]
space_generator: "SpaceGenerator"
search_strategy: "SearchStrategy"
space_generator: Optional["SpaceGenerator"]
search_strategy: Optional["SearchStrategy"]
sch_rules: List["ScheduleRule"]
postproc: List["Postproc"]
mutator: List["Mutator"]
Expand Down
134 changes: 134 additions & 0 deletions src/meta_schedule/search_strategy/replay_func.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*!
 * \brief A search strategy that generates measure candidates by repeatedly invoking the
 * space generator and sampling one schedule out of each freshly generated design space.
 */
class ReplayFuncNode : public SearchStrategyNode {
 public:
  using TRandState = support::LinearCongruentialEngine::TRandState;

  /*! \brief The state of the search strategy, alive only between PreTuning and PostTuning. */
  struct State {
    /*! \brief The search strategy itself (non-owning back-pointer). */
    ReplayFuncNode* self;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int st;
    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
    int ed;

    // The window starts at [0, num_trials_per_iter) and is advanced by
    // NotifyRunnerResults after each batch of measurements.
    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}

    /*! \brief Produce the next batch of candidates, or NullOpt when the trial budget is spent. */
    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
    /*! \brief Advance the `[st, ed)` window by one batch; the results themselves are unused. */
    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
  };

  /*! \brief The number of trials per iteration. */
  int num_trials_per_iter;
  /*! \brief The number of total trials. */
  int num_trials_total;

  /*! \brief The module to be tuned. */
  IRModule mod_{nullptr};
  /*! \brief The metadata of the function arguments. */
  Array<ArgInfo> args_info_{nullptr};
  /*! \brief The space generator for measure candidates generation. */
  SpaceGenerator space_generator_{nullptr};
  /*! \brief The random state. -1 means using random number. */
  TRandState rand_state_ = -1;
  /*! \brief The state of the search strategy. */
  std::unique_ptr<State> state_ = nullptr;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("num_trials_per_iter", &num_trials_per_iter);
    v->Visit("num_trials_total", &num_trials_total);
    // Runtime-only members are deliberately not visited:
    // `mod_` is not visited
    // `args_info_` is not visited
    // `space_generator_` is not visited
    // `rand_state_` is not visited
    // `state_` is not visited
  }

  static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);

  // Pull the module, space generator and a forked RNG seed out of the tuning
  // context; any stale tuning state from a previous run is discarded.
  void InitializeWithTuneContext(const TuneContext& tune_context) final {
    this->space_generator_ = tune_context->space_generator.value();
    this->mod_ = tune_context->mod.value();
    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value()));
    this->rand_state_ = ForkSeed(&tune_context->rand_state);
    this->state_.reset();
  }

  // Note: `design_spaces` is ignored — this strategy regenerates the design
  // space on every trial via the space generator instead of replaying a fixed set.
  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
    ICHECK(this->state_ == nullptr);
    this->state_ = std::make_unique<State>(this);
  }

  void PostTuning() final {
    ICHECK(this->state_ != nullptr);
    this->state_.reset();
  }

  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
    ICHECK(this->state_ != nullptr);
    return this->state_->GenerateMeasureCandidates();
  }

  void NotifyRunnerResults(const Array<RunnerResult>& results) final {
    ICHECK(this->state_ != nullptr);
    this->state_->NotifyRunnerResults(results);
  }
};

// Produce the next batch of measure candidates. Each candidate is drawn by
// regenerating the design space from the module and uniformly sampling one of
// the resulting schedules. Returns NullOpt once the trial budget is exhausted.
inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
  if (st >= self->num_trials_total) {
    return NullOpt;
  }
  // Clip the window so the final batch never overruns the total budget.
  ed = std::min(ed, self->num_trials_total);
  int batch_size = ed - st;
  Array<MeasureCandidate> candidates;
  for (int count = 0; count < batch_size; ++count) {
    Array<tir::Schedule> spaces = self->space_generator_->GenerateDesignSpace(self->mod_);
    int picked = tir::SampleInt(&self->rand_state_, 0, spaces.size());
    candidates.push_back(MeasureCandidate(spaces[picked], self->args_info_));
  }
  return candidates;
}

// Runner feedback is not used by this strategy; simply slide the `[st, ed)`
// window forward by one batch of trials.
inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
  const int batch = self->num_trials_per_iter;
  st += batch;
  ed += batch;
}

/*!
 * \brief Construct a replay-func search strategy.
 * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
 * \param num_trials_total The total number of trials.
 * \return The constructed SearchStrategy.
 */
SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) {
  auto node = make_object<ReplayFuncNode>();
  node->num_trials_per_iter = num_trials_per_iter;
  node->num_trials_total = num_trials_total;
  return SearchStrategy(node);
}

// Register the node type and expose the constructor to the FFI under the
// global name "meta_schedule.SearchStrategyReplayFunc" (used by the Python side).
TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
    .set_body_typed(SearchStrategy::ReplayFunc);

} // namespace meta_schedule
} // namespace tvm
2 changes: 1 addition & 1 deletion src/meta_schedule/search_strategy/replay_trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ReplayTraceNode : public SearchStrategyNode {
this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value()));
}

this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_[0]));
this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value()));
this->rand_state_ = ForkSeed(&tune_context->rand_state);
this->state_.reset();
}
Expand Down
Loading

0 comments on commit ca3f72c

Please sign in to comment.