Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] PyDatabase Complete Function Reload Support #12838

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
Expand Down Expand Up @@ -267,6 +268,33 @@ class PyDatabaseNode : public DatabaseNode {
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `QueryTuningRecord` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
using FQueryTuningRecord = runtime::TypedPackedFunc<Optional<TuningRecord>(
const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `QuerySchedule` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt if not found.
*/
using FQuerySchedule = runtime::TypedPackedFunc<Optional<tir::Schedule>(
const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `QueryIRModule` method.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
*/
using FQueryIRModule =
runtime::TypedPackedFunc<Optional<IRModule>(const IRModule&, const Target&, const String&)>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
Expand All @@ -283,6 +311,12 @@ class PyDatabaseNode : public DatabaseNode {
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `QueryTuningRecord` function. */
FQueryTuningRecord f_query_tuning_record;
/*! \brief The packed function to the `QuerySchedule` function. */
FQuerySchedule f_query_schedule;
/*! \brief The packed function to the `QueryIRModule` function. */
FQueryIRModule f_query_ir_module;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

Expand All @@ -295,6 +329,9 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_query_tuning_record` is not visited
// `f_query_schedule` is not visited
// `f_query_ir_module` is not visited
// `f_size` is not visited
}

Expand Down Expand Up @@ -325,6 +362,33 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_all_tuning_records();
}

Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_tuning_record == nullptr) {
return DatabaseNode::QueryTuningRecord(mod, target, workload_name);
} else {
return f_query_tuning_record(mod, target, workload_name);
}
}

Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_schedule == nullptr) {
return DatabaseNode::QuerySchedule(mod, target, workload_name);
} else {
return f_query_schedule(mod, target, workload_name);
}
}

Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name) final {
if (f_query_ir_module == nullptr) {
return DatabaseNode::QueryIRModule(mod, target, workload_name);
} else {
return f_query_ir_module(mod, target, workload_name);
}
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
Expand Down Expand Up @@ -380,6 +444,9 @@ class Database : public runtime::ObjectRef {
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_query_tuning_record The packed function of `QueryTuningRecord`.
* \param f_query_schedule The packed function of `QuerySchedule`.
* \param f_query_ir_module The packed function of `QueryIRModule`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
Expand All @@ -388,6 +455,9 @@ class Database : public runtime::ObjectRef {
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
PyDatabaseNode::FQuerySchedule f_query_schedule,
PyDatabaseNode::FQueryIRModule f_query_ir_module,
PyDatabaseNode::FSize f_size);
/*! \return The current Database in the scope. */
static Optional<Database> Current();
Expand Down
81 changes: 81 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ def __init__(
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_query_tuning_record: Callable = None,
f_query_schedule: Callable = None,
f_query_ir_module: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
Expand All @@ -389,6 +392,9 @@ def __init__(
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_query_tuning_record,
f_query_schedule,
f_query_ir_module,
f_size,
)

Expand All @@ -409,6 +415,9 @@ class PyDatabase:
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"query_tuning_record",
"query_schedule",
"query_ir_module",
"__len__",
],
}
Expand Down Expand Up @@ -478,6 +487,78 @@ def get_all_tuning_records(self) -> List[TuningRecord]:
"""
raise NotImplementedError

def query_tuning_record(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[TuningRecord]:
"""Query a tuning record from the database.

Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.

Returns
-------
record : Optional[TuningRecord]
The tuning record corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQueryTuningRecord( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def query_schedule(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[Schedule]:
"""Query a schedule from the database.

Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.

Returns
-------
schedule : Optional[Schedule]
The schedule corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQuerySchedule( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def query_ir_module(
self, mod: IRModule, target: Target, workload_name: Optional[str] = None
) -> Optional[IRModule]:
"""Query an IRModule from the database.

Parameters
----------
mod : IRModule
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : Optional[str]
The workload name to be searched for.

Returns
-------
mod : Optional[IRModule]
The IRModule corresponding to the given workload.
"""
# Using self._outer to replace the self pointer
return _ffi_api.DatabaseQueryIRModule( # type: ignore # pylint: disable=no-member
self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member
)

def __len__(self) -> int:
"""Get the number of records in the database.

Expand Down
6 changes: 6 additions & 0 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,19 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FQueryTuningRecord f_query_tuning_record,
PyDatabaseNode::FQuerySchedule f_query_schedule,
PyDatabaseNode::FQueryIRModule f_query_ir_module,
PyDatabaseNode::FSize f_size) {
ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
n->f_has_workload = f_has_workload;
n->f_commit_workload = f_commit_workload;
n->f_commit_tuning_record = f_commit_tuning_record;
n->f_get_top_k = f_get_top_k;
n->f_get_all_tuning_records = f_get_all_tuning_records;
n->f_query_tuning_record = f_query_tuning_record;
n->f_query_schedule = f_query_schedule;
n->f_query_ir_module = f_query_ir_module;
n->f_size = f_size;
return Database(n);
}
Expand Down
Loading