Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule] Introduce ScheduleFnDatabase (apache#12626)
Browse files Browse the repository at this point in the history
Following apache#12520, this PR introduces `ScheduleFnDatabase`, a mocked
database to allow injecting handcrafted schedules provided by a schedule
function.

The schedule function comes with the following signature:

```python
def schedule_fn(
  sch: tir.Schedule,
) -> bool:
  task_name = sch.mod.attrs["task_name"]
  # ^^^ provides an optional name of the task queried
  ...
```

This mocked database helps incorporate the existing testing utility
`apply_fixed_schedule` more formally into the MetaSchedule-Relay build
pipeline, and allows further extension to Relax with the same interface.

Next as another follow-up, we will introduce ConcatDatabase that allows
mixing multiple databases, including the mocked and ones from JSON
files.
  • Loading branch information
junrushao authored and xinetzone committed Nov 25, 2022
1 parent 4f302ff commit a9f2bf3
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 135 deletions.
19 changes: 16 additions & 3 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,29 @@ class DatabaseNode : public runtime::Object {
* \brief Query the best record of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name);
/*!
* \brief Query the best schedule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt if not found.
*/
virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name);
/*!
* \brief Query the best IRModule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
*/
virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name);

static constexpr const char* _type_key = "meta_schedule.Database";
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
Expand Down Expand Up @@ -336,6 +342,13 @@ class Database : public runtime::ObjectRef {
public:
/*! An in-memory database. */
TVM_DLL static Database MemoryDatabase();
/*!
* \brief A database for injecting handcrafted schedule functions.
* \param schedule_fn The function to do scheduling, which takes a TIR schedule,
* and returns a boolean indicating if the schedule is successful.
*/
TVM_DLL static Database ScheduleFnDatabase(
runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn);
/*!
* \brief Create a default database that uses JSON file for tuning records.
* \param path_workload The path to the workload table.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
from .schedule_fn_database import ScheduleFnDatabase
41 changes: 32 additions & 9 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,12 @@ def __len__(self) -> int:
"""
return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member

def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]:
def query_tuning_record(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[TuningRecord]:
"""Query the best record of the given workload from the database.
Parameters
Expand All @@ -244,15 +249,22 @@ def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningR
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.
Returns
-------
tuning_record : Optional[TuningRecord]
The best record of the given workload; None if not found.
"""
return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQueryTuningRecord(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
def query_schedule(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[Schedule]:
"""Query the best schedule of the given workload from the database.
Parameters
Expand All @@ -261,15 +273,22 @@ def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.
Returns
-------
schedule : Optional[Schedule]
The best schedule of the given workload; None if not found.
"""
return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQuerySchedule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
def query_ir_module(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[IRModule]:
"""Query the best IRModule of the given workload from the database.
Parameters
Expand All @@ -278,18 +297,22 @@ def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.
Returns
-------
ir_module : Optional[IRModule]
The best IRModule of the given workload; None if not found.
"""
return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query(
self,
mod: IRModule,
target: Target,
*,
workload_name: str = "main",
kind: Union[
Literal["schedule"],
Literal["record"],
Expand All @@ -313,11 +336,11 @@ def query(
The best optimization outcome of the given workload.
"""
if kind == "schedule":
return self.query_schedule(mod, target)
return self.query_schedule(mod, target, workload_name)
if kind == "record":
return self.query_tuning_record(mod, target)
return self.query_tuning_record(mod, target, workload_name)
if kind == "ir_module":
return self.query_ir_module(mod, target)
return self.query_ir_module(mod, target, workload_name)
raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"')

def __enter__(self) -> "Database":
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/meta_schedule/database/schedule_fn_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database for injecting handcrafted schedule functions."""
from typing import Callable

from tvm._ffi import register_object
from tvm.tir import Schedule

from .. import _ffi_api
from .database import Database


@register_object("meta_schedule.ScheduleFnDatabase")
class ScheduleFnDatabase(Database):
"""A database for injecting handcrafted schedule functions."""

def __init__(
self,
schedule_fn: Callable[[Schedule], bool],
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.DatabaseScheduleFnDatabase, # type: ignore # pylint: disable=no-member
schedule_fn,
)
83 changes: 0 additions & 83 deletions python/tvm/meta_schedule/testing/utils.py

This file was deleted.

13 changes: 8 additions & 5 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w

/******** Database ********/

Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target target) {
Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name) {
if (!this->HasWorkload(mod)) {
return NullOpt;
}
Expand All @@ -168,8 +169,9 @@ Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target targ
return records[0];
}

Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target) {
if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target)) {
Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name) {
if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) {
TuningRecord record = opt_record.value();
tir::Schedule sch =
tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
Expand All @@ -181,8 +183,9 @@ Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target)
}
}

Optional<IRModule> DatabaseNode::QueryIRModule(IRModule mod, Target target) {
if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target)) {
Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name) {
if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name)) {
return opt_sch.value()->mod();
} else {
return NullOpt;
Expand Down
10 changes: 5 additions & 5 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class MemoryDatabaseNode : public DatabaseNode {
return false;
}

Workload CommitWorkload(const IRModule& mod) {
Workload CommitWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
return workload;
Expand All @@ -55,9 +55,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return workload;
}

void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); }
void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
std::vector<std::pair<double, TuningRecord>> results;
results.reserve(this->records.size());
for (const TuningRecord& record : records) {
Expand Down Expand Up @@ -91,9 +91,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return ret;
}

Array<TuningRecord> GetAllTuningRecords() { return records; }
Array<TuningRecord> GetAllTuningRecords() final { return records; }

int64_t Size() { return records.size(); }
int64_t Size() final { return records.size(); }
};

Database Database::MemoryDatabase() {
Expand Down
Loading

0 comments on commit a9f2bf3

Please sign in to comment.