Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule] PyDatabase Complete Function Reload Support (apache#12838
Browse files Browse the repository at this point in the history
)

* Save for PR.

* Fix database default query function call.

* Add test.

* Fix lint.

* Remove unused import.

* Differentiate override class.

* Reuse outer class functions.

* Fix lint.
  • Loading branch information
zxybazh authored and xinetzone committed Nov 25, 2022
1 parent b6df0d3 commit a26904d
Show file tree
Hide file tree
Showing 4 changed files with 365 additions and 3 deletions.
70 changes: 70 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
Expand Down Expand Up @@ -267,6 +268,33 @@ class PyDatabaseNode : public DatabaseNode {
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `QueryTuningRecord` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
using FQueryTuningRecord = runtime::TypedPackedFunc<Optional<TuningRecord>(
const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `QuerySchedule` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt if not found.
*/
using FQuerySchedule = runtime::TypedPackedFunc<Optional<tir::Schedule>(
const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `QueryIRModule` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
*/
using FQueryIRModule =
runtime::TypedPackedFunc<Optional<IRModule>(const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
Expand All @@ -283,6 +311,12 @@ class PyDatabaseNode : public DatabaseNode {
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `QueryTuningRecord` function. */
FQueryTuningRecord f_query_tuning_record;
/*! \brief The packed function to the `QuerySchedule` function. */
FQuerySchedule f_query_schedule;
/*! \brief The packed function to the `QueryIRModule` function. */
FQueryIRModule f_query_ir_module;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

Expand All @@ -295,6 +329,9 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_query_tuning_record` is not visited
// `f_query_schedule` is not visited
// `f_query_ir_module` is not visited
// `f_size` is not visited
}

Expand Down Expand Up @@ -325,6 +362,33 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_all_tuning_records();
}

Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_tuning_record == nullptr) {
return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
} else {
return f_query_tuning_record(mod, target, workload_name);
}
}

Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_schedule == nullptr) {
return DatabaseNode::QuerySchedule(mod, target, workload_name);
} else {
return f_query_schedule(mod, target, workload_name);
}
}

Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_ir_module == nullptr) {
return DatabaseNode::QueryIRModule(mod, target, workload_name);
} else {
return f_query_ir_module(mod, target, workload_name);
}
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
Expand Down Expand Up @@ -380,6 +444,9 @@ class Database : public runtime::ObjectRef {
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_query_tuning_record The packed function of `QueryTuningRecord`.
* \param f_query_schedule The packed function of `QuerySchedule`.
* \param f_query_ir_module The packed function of `QueryIRModule`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
Expand All @@ -388,6 +455,9 @@ class Database : public runtime::ObjectRef {
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
PyDatabaseNode::FQuerySchedule f_query_schedule,
PyDatabaseNode::FQueryIRModule f_query_ir_module,
PyDatabaseNode::FSize f_size);
/*! \return The current Database in the scope. */
static Optional<Database> Current();
Expand Down
81 changes: 81 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ def __init__(
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_query_tuning_record: Callable = None,
f_query_schedule: Callable = None,
f_query_ir_module: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
Expand All @@ -389,6 +392,9 @@ def __init__(
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_query_tuning_record,
f_query_schedule,
f_query_ir_module,
f_size,
)

Expand All @@ -409,6 +415,9 @@ class PyDatabase:
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"query_tuning_record",
"query_schedule",
"query_ir_module",
"__len__",
],
}
Expand Down Expand Up @@ -478,6 +487,78 @@ def get_all_tuning_records(self) -> List[TuningRecord]:
"""
raise NotImplementedError

def query_tuning_record(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[TuningRecord]:
"""Query a tuning record from the database.
Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.
Returns
-------
record : Optional[TuningRecord]
The tuning record corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQueryTuningRecord( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def query_schedule(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[Schedule]:
"""Query a schedule from the database.
Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.
Returns
-------
schedule : Optional[Schedule]
The schedule corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQuerySchedule( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def query_ir_module(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[IRModule]:
"""Query an IRModule from the database.
Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.
Returns
-------
mod : Optional[IRModule]
The IRModule corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQueryIRModule( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def __len__(self) -> int:
"""Get the number of records in the database.
Expand Down
6 changes: 6 additions & 0 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,19 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
PyDatabaseNode::FQuerySchedule f_query_schedule,
PyDatabaseNode::FQueryIRModule f_query_ir_module,
PyDatabaseNode::FSize f_size) {
ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
n->f_has_workload = f_has_workload;
n->f_commit_workload = f_commit_workload;
n->f_commit_tuning_record = f_commit_tuning_record;
n->f_get_top_k = f_get_top_k;
n->f_get_all_tuning_records = f_get_all_tuning_records;
n->f_query_tuning_record = f_query_tuning_record;
n->f_query_schedule = f_query_schedule;
n->f_query_ir_module = f_query_ir_module;
n->f_size = f_size;
return Database(n);
}
Expand Down
Loading

0 comments on commit a26904d

Please sign in to comment.