Skip to content

Commit

Permalink
Rebase.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Nov 4, 2021
1 parent b8ced4d commit 7b22b2e
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

Expand Down
6 changes: 1 addition & 5 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,7 @@ class TaskScheduler : public runtime::ObjectRef {
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id, //
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database);
PyTaskSchedulerNode::FNextTaskId f_next_task_id);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
};

Expand Down
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.utils import _get_hex_address

from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..builder import BuilderResult
from ..runner import RunnerResult
from ..utils import _get_hex_address, check_override

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext
from ..task_scheduler import TaskScheduler


Expand Down Expand Up @@ -77,6 +77,7 @@ class PyMeasureCallback(MeasureCallback):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, MeasureCallback)
def f_apply(
task_scheduler: "TaskScheduler",
tasks: List[TuneContext],
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/measure_callback/measure_callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
const auto* self = n.as<PyMeasureCallbackNode>();
ICHECK(self);
PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string;
ICHECK(f_as_string != nullptr);
ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!";
p->stream << f_as_string();
});

Expand Down
10 changes: 1 addition & 9 deletions src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
PyTaskSchedulerNode::FNextTaskId f_next_task_id, //
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database) {
PyTaskSchedulerNode::FNextTaskId f_next_task_id) {
ObjectPtr<PyTaskSchedulerNode> n = make_object<PyTaskSchedulerNode>();
n->tasks = tasks;
n->builder = builder;
Expand All @@ -234,10 +230,6 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
n->f_is_task_running = f_is_task_running;
n->f_join_running_task = f_join_running_task;
n->f_next_task_id = f_next_task_id;
n->tasks = tasks;
n->builder = builder;
n->runner = runner;
n->database = database;
return TaskScheduler(n);
}

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_meta_schedule_measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ def apply(

def test_meta_schedule_measure_callback_as_string():
class NotSoFancyMeasureCallback(PyMeasureCallback):
def apply(
self,
task_scheduler: "TaskScheduler",
tasks: List["TuneContext"],
measure_candidates: List[MeasureCandidate],
builds: List[BuilderResult],
results: List[RunnerResult],
) -> bool:
pass

def __str__(self) -> str:
return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})"

Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def apply(self, sch: Schedule) -> bool:

def test_meta_schedule_postproc_as_string():
class NotSoFancyPostproc(PyPostproc):
def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
pass

def apply(self, sch: Schedule) -> bool:
pass

def __str__(self) -> str:
return f"NotSoFancyPostproc({_get_hex_address(self.handle)})"

Expand Down

0 comments on commit 7b22b2e

Please sign in to comment.