[M3c][MetaScheduler] Add ReplayFunc Search Strategy. (apache#9799)
* Modify TuneContext, TaskScheduler & SearchStrategy functions.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Retrigger CI.

* Add ReplayFunc and EvolutionarySearch strategy.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Fix optional task name.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Remove extra files.

* Fix things.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
7 people authored and ylc committed Jan 7, 2022
1 parent 02d754a commit d731916
Showing 6 changed files with 335 additions and 10 deletions.
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/search_strategy/__init__.py
@@ -19,5 +19,7 @@
Meta Schedule search strategy utilizes the design spaces given
to generate measure candidates.
"""
-from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
-from .replay_trace import ReplayTrace
+
+from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
+from .replay_trace import ReplayTrace, ReplayTraceConfig
+from .replay_func import ReplayFunc, ReplayFuncConfig
63 changes: 63 additions & 0 deletions python/tvm/meta_schedule/search_strategy/replay_func.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Replay Trace Search Strategy"""
from typing import NamedTuple

from tvm._ffi import register_object

from .. import _ffi_api
from .search_strategy import SearchStrategy


@register_object("meta_schedule.ReplayFunc")
class ReplayFunc(SearchStrategy):
"""
Replay Func Search Strategy is a search strategy that generates measure candidates by
calling a design space generator and transform the design space.
Parameters
----------
num_trials_per_iter : int
Number of trials per iteration.
num_trials_total : int
Total number of trials.
"""

num_trials_per_iter: int
num_trials_total: int

def __init__(
self,
num_trials_per_iter: int,
num_trials_total: int,
):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
num_trials_total,
)


class ReplayFuncConfig(NamedTuple):
"""Configuration for ReplayFunc"""

num_trials_per_iter: int
num_trials_total: int

def create_strategy(self) -> ReplayFunc:
return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
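
As a usage sketch (not part of the commit), the two classes above can be wired together as follows. Only ReplayFunc and ReplayFuncConfig from this file are assumed; the trial counts are arbitrary.

from tvm.meta_schedule.search_strategy import ReplayFunc, ReplayFuncConfig

# ReplayFuncConfig is a plain NamedTuple, so it is cheap to build, compare and serialize.
config = ReplayFuncConfig(num_trials_per_iter=7, num_trials_total=20)
strategy = config.create_strategy()  # equivalent to ReplayFunc(7, 20)
assert isinstance(strategy, ReplayFunc)
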
57 changes: 57 additions & 0 deletions src/meta_schedule/mutator/mutator.cc
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

Mutator Mutator::PyMutator(
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyMutatorNode::FApply f_apply, //
PyMutatorNode::FAsString f_as_string) {
ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_apply = std::move(f_apply);
n->f_as_string = std::move(f_as_string);
return Mutator(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PyMutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
const auto* self = n.as<PyMutatorNode>();
ICHECK(self);
PyMutatorNode::FAsString f_as_string = (*self).f_as_string;
ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!";
p->stream << f_as_string();
});

TVM_REGISTER_OBJECT_TYPE(MutatorNode);
TVM_REGISTER_NODE_TYPE(PyMutatorNode);

TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext")
.set_body_method<Mutator>(&MutatorNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
.set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional<tir::Trace> {
TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
return self->Apply(trace, &seed_);
});
TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);

} // namespace meta_schedule
} // namespace tvm
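
The TVM_REGISTER_GLOBAL calls above expose the mutator entry points to the Python FFI. Below is a hedged sketch of looking them up; only the global names registered in this file are assumed, and building a concrete Mutator instance is outside the scope of this diff.

import tvm

# The names mirror the TVM_REGISTER_GLOBAL strings in mutator.cc above.
mutator_init = tvm.get_global_func("meta_schedule.MutatorInitializeWithTuneContext")
mutator_apply = tvm.get_global_func("meta_schedule.MutatorApply")

# Per the C++ lambda above, MutatorApply(mutator, trace, seed) draws a fresh
# device-random seed when seed == -1 before delegating to MutatorNode::Apply,
# so callers can simply pass -1 for non-deterministic mutation.
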
53 changes: 53 additions & 0 deletions src/meta_schedule/postproc/postproc.cc
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

Postproc Postproc::PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string) {
ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
n->f_apply = std::move(f_apply);
n->f_as_string = std::move(f_as_string);
return Postproc(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PyPostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
const auto* self = n.as<PyPostprocNode>();
ICHECK(self);
PyPostprocNode::FAsString f_as_string = (*self).f_as_string;
ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!";
p->stream << f_as_string();
});

TVM_REGISTER_OBJECT_TYPE(PostprocNode);
TVM_REGISTER_NODE_TYPE(PyPostprocNode);

TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
.set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);

} // namespace meta_schedule
} // namespace tvm
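
postproc.cc mirrors mutator.cc almost line for line. The piece that matters for the new search strategy is the Apply contract: a post-processor receives a schedule, may rewrite it, and returns whether it succeeded, so the caller discards the candidate on failure. A toy stand-in for that contract is shown below (illustrative only; the Python-side PyPostproc wrapper is not part of this diff).

from tvm.tir.schedule import Schedule

def toy_postproc_apply(sch: Schedule) -> bool:
    # A real post-processor rewrites `sch` in place and returns False to tell
    # the caller to throw the candidate away and resample.
    return sch is not None
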
151 changes: 151 additions & 0 deletions src/meta_schedule/search_strategy/replay_func.cc
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief A search strategy that generates measure candidates using the space generator. */
class ReplayFuncNode : public SearchStrategyNode {
public:
/*! \brief The state of the search strategy. */
struct State {
/*! \brief The search strategy itself */
ReplayFuncNode* self;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int st;
/*! \brief `[st, ed)` are the indices of the next batch of candidates. */
int ed;

explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}

inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
inline void NotifyRunnerResults(const Array<RunnerResult>& results);
};

/*! \brief The number of trials per iteration. */
int num_trials_per_iter;
/*! \brief The number of total trials. */
int num_trials_total;

/*! \brief The module to be tuned. */
IRModule mod_{nullptr};
/*! \brief The metadata of the function arguments. */
Array<ArgInfo> args_info_{nullptr};
/*! \brief The post processors */
Array<Postproc> postprocs_{nullptr};
/*! \brief The space generator for measure candidates generation. */
SpaceGenerator space_generator_{nullptr};
/*! \brief The random state. -1 means using random number. */
TRandState rand_state_ = -1;
/*! \brief The state of the search strategy. */
std::unique_ptr<State> state_ = nullptr;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_trials_per_iter", &num_trials_per_iter);
v->Visit("num_trials_total", &num_trials_total);
// `space_generator_` is not visited
// `mod_` is not visited
// `args_info_` is not visited
// `postprocs_` is not visited
// `rand_state_` is not visited
// `state_` is not visited
}

static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);

void InitializeWithTuneContext(const TuneContext& context) final {
this->space_generator_ = context->space_generator.value();
this->mod_ = context->mod.value();
this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
this->postprocs_ = context->postprocs;
this->rand_state_ = ForkSeed(&context->rand_state);
this->state_.reset();
}

void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(this->state_ == nullptr);
this->state_ = std::make_unique<State>(this);
}

void PostTuning() final {
ICHECK(this->state_ != nullptr);
this->state_.reset();
}

Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
ICHECK(this->state_ != nullptr);
return this->state_->GenerateMeasureCandidates();
}

void NotifyRunnerResults(const TuneContext& context,
const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) final {
ICHECK(this->state_ != nullptr);
this->state_->NotifyRunnerResults(results);
}
};

inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
if (st >= self->num_trials_total) {
return NullOpt;
}
ed = std::min(ed, self->num_trials_total);
Array<MeasureCandidate> result;
for (int i = st; i < ed; i++) {
for (;;) {
Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
tir::Schedule sch = schs[design_space_index];
sch->EnterPostproc();
bool failed = false;
for (const Postproc& proc : self->postprocs_) {
if (!proc->Apply(sch)) {
failed = true;
break;
}
}
if (!failed) {
result.push_back(MeasureCandidate(sch, self->args_info_));
break;
}
}
}
return result;
}

inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
st += self->num_trials_per_iter;
ed += self->num_trials_per_iter;
}

SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) {
ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
n->num_trials_per_iter = num_trials_per_iter;
n->num_trials_total = num_trials_total;
return SearchStrategy(n);
}

TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
.set_body_typed(SearchStrategy::ReplayFunc);

} // namespace meta_schedule
} // namespace tvm
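
To make the control flow of State::GenerateMeasureCandidates easier to follow, here is a self-contained Python rendering of the same loop. It is a sketch rather than TVM API: the space generator, post-processors and schedules are stand-ins, and the EnterPostproc call and the MeasureCandidate/ArgInfo packaging from the C++ above are reduced to comments.

import random
from typing import Callable, List, Optional

def generate_measure_candidates(
    st: int,                                            # start index of this batch, as in State::st
    ed: int,                                            # end index (exclusive), as in State::ed
    num_trials_total: int,
    generate_design_space: Callable[[], List[object]],  # stand-in for SpaceGenerator::GenerateDesignSpace
    postprocs: List[Callable[[object], bool]],          # stand-in for the Array<Postproc> member
    rng: random.Random,                                 # stand-in for the forked TRandState
) -> Optional[List[object]]:
    if st >= num_trials_total:
        return None                                     # NullOpt: the trial budget is exhausted
    ed = min(ed, num_trials_total)                      # clip the final, possibly smaller, batch
    result: List[object] = []
    for _ in range(st, ed):
        while True:                                     # the C++ `for (;;)`: resample until postprocs pass
            spaces = generate_design_space()            # re-run the design space generator every attempt
            sch = spaces[rng.randrange(len(spaces))]    # SampleInt over the generated design spaces
            # (the C++ also calls sch->EnterPostproc() before running the postprocs)
            if all(proc(sch) for proc in postprocs):
                result.append(sch)                      # C++ wraps this as MeasureCandidate(sch, args_info)
                break
    return result

# Example with trivial stand-ins: two dummy "design spaces" and a postproc that accepts everything.
batch = generate_measure_candidates(
    st=0, ed=7, num_trials_total=20,
    generate_design_space=lambda: ["space-a", "space-b"],
    postprocs=[lambda s: True],
    rng=random.Random(0),
)
assert batch is not None and len(batch) == 7
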
15 changes: 7 additions & 8 deletions tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -18,12 +18,11 @@
# pylint: disable=missing-function-docstring
import sys
import pytest
from typing import List

import tvm
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import (
ReplayFunc,
ReplayTrace,
SearchStrategy,
)
@@ -75,17 +74,17 @@ def _schedule_matmul(sch: Schedule):
sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)


-@pytest.mark.parametrize("TestClass", [ReplayTrace])
+@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace])
def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name
num_trials_per_iter = 7
num_trials_total = 20

strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
-context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
-context.space_generator.initialize_with_tune_context(context)
-spaces = context.space_generator.generate_design_space(context.mod)
+tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+tune_context.space_generator.initialize_with_tune_context(tune_context)
+spaces = tune_context.space_generator.generate_design_space(tune_context.mod)

-strategy.initialize_with_tune_context(context)
+strategy.initialize_with_tune_context(tune_context)
strategy.pre_tuning(spaces)
(correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
num_trials_each_iter: List[int] = []
@@ -100,7 +99,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name
remove_decisions=(isinstance(strategy, ReplayTrace)),
)
runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
-strategy.notify_runner_results(context, candidates, runner_results)
+strategy.notify_runner_results(tune_context, candidates, runner_results)
candidates = strategy.generate_measure_candidates()
strategy.post_tuning()
assert num_trials_each_iter == [7, 7, 6]
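
The expected [7, 7, 6] is just the trial budget split into batches: 20 total trials are consumed at most 7 at a time until the budget runs out. A quick check of that arithmetic:

num_trials_total, num_trials_per_iter = 20, 7
batch_sizes = []
remaining = num_trials_total
while remaining > 0:
    batch_sizes.append(min(num_trials_per_iter, remaining))
    remaining -= batch_sizes[-1]
assert batch_sizes == [7, 7, 6]
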