Skip to content

Commit

Permalink
User-Interface: Tune-TIR
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Dec 1, 2021
1 parent 1c35161 commit 74e8c74
Show file tree
Hide file tree
Showing 24 changed files with 1,147 additions and 128 deletions.
41 changes: 28 additions & 13 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,14 @@ class MeasureCallbackNode : public runtime::Object {
* \param task_scheduler The task scheduler.
* \param tasks The list of tune context to process.
* \param measure_candidates The measure candidates.
* \param builds The builder results by building the measure candidates.
* \param results The runner results by running the built measure candidates.
* \return Whether the measure callback was successfully applied.
* \param builder_results The builder results by building the measure candidates.
* \param runner_results The runner results by running the built measure candidates.
*/
virtual bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
virtual void Apply(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) = 0;
const Array<BuilderResult>& builder_results, //
const Array<RunnerResult>& runner_results) = 0;

static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
Expand All @@ -70,8 +69,8 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
* \return Whether the measure callback was successfully applied.
*/
using FApply =
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results)>;
Expand All @@ -91,13 +90,13 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
// `f_as_string` is not visited
}

bool Apply(const TaskScheduler& task_scheduler, //
const Array<TuneContext> tasks, //
void Apply(const TaskScheduler& task_scheduler, //
int task_id, //
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results);
}

static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
Expand All @@ -110,12 +109,28 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
*/
class MeasureCallback : public runtime::ObjectRef {
public:
/*!
* \brief Create a measure callback that adds the measurement results into the database
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback AddToDatabase();
/*!
* \brief Create a measure callback that adds the measurement results into the database
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback RemoveBuildArtifact();
/*!
* \brief Create a measure callback that echos the statistics of the tuning process to the console
* \param f_count_flops The function to count FLOPs
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback EchoStatistics();
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The list of measure callbacks of the scheduler. */
Optional<Array<MeasureCallback>> measure_callbacks;
Array<MeasureCallback> measure_callbacks;

/*! \brief The default desctructor. */
/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand Down
11 changes: 8 additions & 3 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/mutator.h>
#include <tvm/meta_schedule/postproc.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
Expand Down Expand Up @@ -49,18 +51,20 @@ class TuneContextNode : public runtime::Object {
/*! \brief The mutators. */
Optional<Array<Mutator>> mutators;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
String task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
int num_threads;

/*! \brief Whether the tuning task has been stopped or finished. */
bool is_stopped;
/*! \brief Packed functions to fetch the runner results asynchronously. */
Optional<Array<RunnerFuture>> runner_futures;
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
Optional<Array<BuilderResult>> builder_results;
/*! \brief Packed functions to fetch the runner results asynchronously. */
Optional<Array<RunnerFuture>> runner_futures;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mod", &mod);
Expand All @@ -74,6 +78,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_stopped", &is_stopped);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
}
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@
from . import integration
from . import feature_extractor
from . import cost_model
from .tune_context import TuneContext
from .search_strategy import MeasureCandidate
from .tune_context import TuneContext
from .tune import tune_tir
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/measure_callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@
The tvm.meta_schedule.measure_callback package.
"""
from .measure_callback import MeasureCallback, PyMeasureCallback
from .add_to_database import AddToDatabase
from .echo_statistics import EchoStatistics
from .remove_build_artifact import RemoveBuildArtifact
28 changes: 28 additions & 0 deletions python/tvm/meta_schedule/measure_callback/add_to_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm._ffi import register_object

from .. import _ffi_api
from .measure_callback import MeasureCallback


@register_object("meta_schedule.AddToDatabase")
class AddToDatabase(MeasureCallback):
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member
)
28 changes: 28 additions & 0 deletions python/tvm/meta_schedule/measure_callback/echo_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm._ffi import register_object

from .. import _ffi_api
from .measure_callback import MeasureCallback


@register_object("meta_schedule.EchoStatistics")
class EchoStatistics(MeasureCallback):
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member
)
49 changes: 27 additions & 22 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from tvm._ffi import register_object
from tvm.runtime import Object

from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..builder import BuilderResult
from ..runner import RunnerResult
Expand All @@ -40,33 +39,33 @@ class MeasureCallback(Object):
def apply(
self,
task_scheduler: "TaskScheduler",
tasks: List["TuneContext"],
task_id: int,
measure_candidates: List[MeasureCandidate],
builds: List[BuilderResult],
results: List[RunnerResult],
) -> bool:
builder_results: List[BuilderResult],
runner_results: List[RunnerResult],
) -> None:
"""Apply a measure callback to the given schedule.
Parameters
----------
task_scheduler: TaskScheduler
The task scheduler.
tasks: List[TuneContext]
The list of tune context to process.
measure_candidats: List[MeasureCandidate]
task_id: int
The task id.
measure_candidates: List[MeasureCandidate]
The measure candidates.
builds: List[BuilderResult]
builder_results: List[BuilderResult]
The builder results by building the measure candidates.
results: List[RunnerResult]
runner_results: List[RunnerResult]
The runner results by running the built measure candidates.
Returns
-------
result : bool
Whether the measure callback was successfully applied.
"""
return _ffi_api.MeasureCallbackApply(
self, task_scheduler, tasks, measure_candidates, builds, results
return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member
self,
task_scheduler,
task_id,
measure_candidates,
builder_results,
runner_results,
)


Expand All @@ -80,12 +79,18 @@ def __init__(self):
@check_override(self.__class__, MeasureCallback)
def f_apply(
task_scheduler: "TaskScheduler",
tasks: List[TuneContext],
task_id: int,
measure_candidates: List[MeasureCandidate],
builds: List[BuilderResult],
results: List[RunnerResult],
) -> bool:
return self.apply(task_scheduler, tasks, measure_candidates, builds, results)
builder_results: List[BuilderResult],
runner_results: List[RunnerResult],
) -> None:
return self.apply(
task_scheduler,
task_id,
measure_candidates,
builder_results,
runner_results,
)

def f_as_string() -> str:
return str(self)
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/meta_schedule/measure_callback/remove_build_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm._ffi import register_object
from .measure_callback import MeasureCallback
from .. import _ffi_api


@register_object("meta_schedule.RemoveBuildArtifact")
class RemoveBuildArtifact(MeasureCallback):
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member
)
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
"""

from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
from .replay_trace import ReplayTrace
from .replay_func import ReplayFunc
from .replay_trace import ReplayTrace, ReplayTraceConfig
from .replay_func import ReplayFunc, ReplayFuncConfig
14 changes: 13 additions & 1 deletion python/tvm/meta_schedule/search_strategy/replay_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Replay Trace Search Strategy"""
from typing import NamedTuple

from tvm._ffi import register_object
from .search_strategy import SearchStrategy

from .. import _ffi_api
from .search_strategy import SearchStrategy


@register_object("meta_schedule.ReplayFunc")
Expand Down Expand Up @@ -49,3 +51,13 @@ def __init__(
num_trials_per_iter,
num_trials_total,
)


class ReplayFuncConfig(NamedTuple):
"""Configuration for ReplayFunc"""

num_trials_per_iter: int
num_trials_total: int

def create_strategy(self) -> ReplayFunc:
return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
11 changes: 11 additions & 0 deletions python/tvm/meta_schedule/search_strategy/replay_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Replay Trace Search Strategy"""
from typing import NamedTuple

from tvm._ffi import register_object
from .search_strategy import SearchStrategy
Expand Down Expand Up @@ -45,3 +46,13 @@ def __init__(self, num_trials_per_iter: int, num_trials_total: int):
num_trials_per_iter,
num_trials_total,
)


class ReplayTraceConfig(NamedTuple):
"""Configuration for ReplayTrace"""

num_trials_per_iter: int
num_trials_total: int

def create_strategy(self) -> ReplayTrace:
return ReplayTrace(self.num_trials_per_iter, self.num_trials_total)
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def next_task_id(self) -> int:
Returns
-------
int
next_task_id : int
The next task id.
"""
return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member
Expand Down Expand Up @@ -98,7 +98,7 @@ def _is_task_running(self, task_id: int) -> bool:
Returns
-------
bool
running : bool
Whether the task is running.
"""
return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member
Expand Down
Loading

0 comments on commit 74e8c74

Please sign in to comment.