Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass #496

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode {
}

Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
return f_build(build_inputs);
}

Expand Down
23 changes: 17 additions & 6 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode {
// `f_size` is not visited
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);

Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); }
Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
}

void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); }
void CommitTuningRecord(const TuningRecord& record) final {
ICHECK(f_commit_tuning_record != nullptr)
<< "PyDatabase's CommitTuningRecord method not implemented!";
f_commit_tuning_record(record);
}

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
return f_get_top_k(workload, top_k);
}

int64_t Size() final { return f_size(); }
int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
};

/*!
Expand Down
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

Expand Down
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ class PyMutatorNode : public MutatorNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyMutator's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Optional<tir::Trace> Apply(const tir::Trace& trace) final { return this->f_apply(trace); }
Optional<tir::Trace> Apply(const tir::Trace& trace) final {
ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
return this->f_apply(trace);
}

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,15 @@ class PyPostprocNode : public PostprocNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyPostproc's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

bool Apply(const tir::Schedule& sch) final { return this->f_apply(sch); }
bool Apply(const tir::Schedule& sch) final {
ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
return this->f_apply(sch);
}

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode {
// `f_run` is not visited
}

Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { return f_run(runner_inputs); }
Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final {
ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!";
return f_run(runner_inputs);
}

static constexpr const char* _type_key = "meta_schedule.PyRunner";
TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,13 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyScheduleRule's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final {
ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
return this->f_apply(sch, block);
}

Expand Down
12 changes: 11 additions & 1 deletion include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,30 @@ class PySearchStrategyNode : public SearchStrategyNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySearchStrategy's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
this->f_pre_tuning(design_spaces);
}

void PostTuning() final { this->f_post_tuning(); }
void PostTuning() final {
ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
this->f_post_tuning();
}

Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
ICHECK(f_generate_measure_candidates != nullptr)
<< "PySearchStrategy's GenerateMeasureCandidates method not implemented!";
return this->f_generate_measure_candidates();
}

void NotifyRunnerResults(const Array<RunnerResult>& results) final {
ICHECK(f_notify_runner_results != nullptr)
<< "PySearchStrategy's NotifyRunnerResults method not implemented!";
this->f_notify_runner_results(results);
}

Expand Down
4 changes: 4 additions & 0 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
f_initialize_with_tune_context(tune_context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
ICHECK(f_generate_design_space != nullptr)
<< "PySpaceGenerator's GenerateDesignSpace method not implemented!";
return f_generate_design_space(mod);
}

Expand Down
33 changes: 28 additions & 5 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class TaskSchedulerNode : public runtime::Object {
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};

class TaskScheduler;

/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
Expand Down Expand Up @@ -183,26 +185,47 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
}

void Tune() final { //
f_tune();
if (f_tune == nullptr) {
TaskSchedulerNode::Tune();
} else {
f_tune();
}
}

void InitializeTask(int task_id) final { //
f_initialize_task(task_id);
if (f_initialize_task == nullptr) {
TaskSchedulerNode::InitializeTask(task_id);
} else {
f_initialize_task(task_id);
}
}

void SetTaskStopped(int task_id) final { //
f_set_task_stopped(task_id);
if (f_set_task_stopped == nullptr) {
TaskSchedulerNode::SetTaskStopped(task_id);
} else {
f_set_task_stopped(task_id);
}
}

bool IsTaskRunning(int task_id) final { //
return f_is_task_running(task_id);
if (f_is_task_running == nullptr) {
return TaskSchedulerNode::IsTaskRunning(task_id);
} else {
return f_is_task_running(task_id);
}
}

void JoinRunningTask(int task_id) final { //
f_join_running_task(task_id);
if (f_join_running_task == nullptr) {
return TaskSchedulerNode::JoinRunningTask(task_id);
} else {
return f_join_running_task(task_id);
}
}

int NextTaskId() final { //
ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!";
return f_next_task_id();
}

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.target import Target

from .. import _ffi_api
from ..utils import check_override


@register_object("meta_schedule.BuilderInput")
Expand Down Expand Up @@ -119,6 +120,7 @@ class PyBuilder(Builder):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Builder)
def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return self.build(build_inputs)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import _json_de_tvm
from ..utils import _json_de_tvm, check_override


@register_object("meta_schedule.Workload")
Expand Down Expand Up @@ -207,15 +207,19 @@ class PyDatabase(Database):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Database)
def f_commit_workload(mod: IRModule) -> Workload:
return self.commit_workload(mod)

@check_override(self.__class__, Database)
def f_commit_tuning_record(record: TuningRecord) -> None:
self.commit_tuning_record(record)

@check_override(self.__class__, Database)
def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]:
return self.get_top_k(workload, top_k)

@check_override(self.__class__, Database, func_name="__len__")
def f_size() -> int:
return len(self)

Expand All @@ -225,4 +229,4 @@ def f_size() -> int:
f_commit_tuning_record,
f_get_top_k,
f_size,
)
)
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.utils import _get_hex_address

from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..builder import BuilderResult
from ..runner import RunnerResult
from ..utils import _get_hex_address, check_override

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext
from ..task_scheduler import TaskScheduler


Expand Down Expand Up @@ -77,6 +77,7 @@ class PyMeasureCallback(MeasureCallback):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, MeasureCallback)
def f_apply(
task_scheduler: "TaskScheduler",
tasks: List[TuneContext],
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.runtime import Object
from tvm.tir.schedule import Trace

from ..utils import _get_hex_address
from ..utils import _get_hex_address, check_override
from .. import _ffi_api

if TYPE_CHECKING:
Expand Down Expand Up @@ -66,9 +66,11 @@ class PyMutator(Mutator):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Mutator)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Mutator)
def f_apply(trace: Trace) -> Optional[Trace]:
return self.apply(trace)

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

from typing import TYPE_CHECKING

from tvm._ffi import register_object, register_func
from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir.schedule import Schedule
from tvm.meta_schedule.utils import _get_hex_address

from .. import _ffi_api
from ..utils import _get_hex_address, check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down Expand Up @@ -75,9 +75,11 @@ class PyPostproc(Postproc):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Postproc)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Postproc)
def f_apply(sch: Schedule) -> bool:
return self.apply(sch)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import check_override


@register_object("meta_schedule.RunnerInput")
Expand Down Expand Up @@ -158,6 +159,7 @@ class PyRunner(Runner):
def __init__(self) -> None:
"""Constructor"""

@check_override(self.__class__, Runner)
def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
return self.run(runner_inputs)

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.runtime import Object
from tvm.tir.schedule import Schedule, BlockRV

from ..utils import _get_hex_address
from ..utils import _get_hex_address, check_override
from .. import _ffi_api

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,9 +72,11 @@ class PyScheduleRule(ScheduleRule):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, ScheduleRule)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, ScheduleRule)
def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]:
return self.apply(sch, block)

Expand Down
Loading