diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index fa488a38ce0a..4092fdae36dd 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace tvm { @@ -267,6 +268,33 @@ class PyDatabaseNode : public DatabaseNode { * \return An Array of all the tuning records in the database. */ using FGetAllTuningRecords = runtime::TypedPackedFunc()>; + /*! + * \brief The function type of `QueryTuningRecord` method. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. + * \return The best record of the given workload; NullOpt if not found. + */ + using FQueryTuningRecord = runtime::TypedPackedFunc( + const IRModule&, const Target&, const String&)>; + /*! + * \brief The function type of `QuerySchedule` method. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. + * \return The schedule in the best schedule of the given workload; NullOpt if not found. + */ + using FQuerySchedule = runtime::TypedPackedFunc( + const IRModule&, const Target&, const String&)>; + /*! + * \brief The function type of `QueryIRModule` method. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. + * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. + */ + using FQueryIRModule = + runtime::TypedPackedFunc(const IRModule&, const Target&, const String&)>; /*! * \brief The function type of `Size` method. * \return The size of the database. @@ -283,6 +311,12 @@ class PyDatabaseNode : public DatabaseNode { FGetTopK f_get_top_k; /*! \brief The packed function to the `GetAllTuningRecords` function. */ FGetAllTuningRecords f_get_all_tuning_records; + /*! \brief The packed function to the `QueryTuningRecord` function. */ + FQueryTuningRecord f_query_tuning_record; + /*! \brief The packed function to the `QuerySchedule` function. */ + FQuerySchedule f_query_schedule; + /*! \brief The packed function to the `QueryIRModule` function. */ + FQueryIRModule f_query_ir_module; /*! \brief The packed function to the `Size` function. */ FSize f_size; @@ -295,6 +329,9 @@ class PyDatabaseNode : public DatabaseNode { // `f_commit_tuning_record` is not visited // `f_get_top_k` is not visited // `f_get_all_tuning_records` is not visited + // `f_query_tuning_record` is not visited + // `f_query_schedule` is not visited + // `f_query_ir_module` is not visited // `f_size` is not visited } @@ -325,6 +362,33 @@ class PyDatabaseNode : public DatabaseNode { return f_get_all_tuning_records(); } + Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const String& workload_name) final { + if (f_query_tuning_record == nullptr) { + return DatabaseNode::QueryTuningRecord(mod, target, workload_name); + } else { + return f_query_tuning_record(mod, target, workload_name); + } + } + + Optional QuerySchedule(const IRModule& mod, const Target& target, + const String& workload_name) final { + if (f_query_schedule == nullptr) { + return DatabaseNode::QuerySchedule(mod, target, workload_name); + } else { + return f_query_schedule(mod, target, workload_name); + } + } + + Optional QueryIRModule(const IRModule& mod, const Target& target, + const String& workload_name) final { + if (f_query_ir_module == nullptr) { + return DatabaseNode::QueryIRModule(mod, target, workload_name); + } else { + return f_query_ir_module(mod, target, workload_name); + } + } + int64_t Size() final { ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; return f_size(); @@ -380,6 +444,9 @@ class Database : public runtime::ObjectRef { * \param f_commit_tuning_record The packed function of `CommitTuningRecord`. * \param f_get_top_k The packed function of `GetTopK`. * \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`. + * \param f_query_tuning_record The packed function of `QueryTuningRecord`. + * \param f_query_schedule The packed function of `QuerySchedule`. + * \param f_query_ir_module The packed function of `QueryIRModule`. * \param f_size The packed function of `Size`. * \return The created database. */ @@ -388,6 +455,9 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, + PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, + PyDatabaseNode::FQuerySchedule f_query_schedule, + PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size); /*! \return The current Database in the scope. */ static Optional Current(); diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 7a1338f46b20..75b78b118eea 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -378,6 +378,9 @@ def __init__( f_commit_tuning_record: Callable = None, f_get_top_k: Callable = None, f_get_all_tuning_records: Callable = None, + f_query_tuning_record: Callable = None, + f_query_schedule: Callable = None, + f_query_ir_module: Callable = None, f_size: Callable = None, ): """Constructor.""" @@ -389,6 +392,9 @@ def __init__( f_commit_tuning_record, f_get_top_k, f_get_all_tuning_records, + f_query_tuning_record, + f_query_schedule, + f_query_ir_module, f_size, ) @@ -409,6 +415,9 @@ class PyDatabase: "commit_tuning_record", "get_top_k", "get_all_tuning_records", + "query_tuning_record", + "query_schedule", + "query_ir_module", "__len__", ], } @@ -478,6 +487,78 @@ def get_all_tuning_records(self) -> List[TuningRecord]: """ raise NotImplementedError + def query_tuning_record( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[TuningRecord]: + """Query a tuning record from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + workload_name : Optional[str] + The workload name to be searched for. + + Returns + ------- + record : Optional[TuningRecord] + The tuning record corresponding to the given workload. + """ + # Using self._outer to replace the self pointer + return _ffi_api.DatabaseQueryTuningRecord( # type: ignore # pylint: disable=no-member + self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member + ) + + def query_schedule( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[Schedule]: + """Query a schedule from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + workload_name : Optional[str] + The workload name to be searched for. + + Returns + ------- + schedule : Optional[Schedule] + The schedule corresponding to the given workload. + """ + # Using self._outer to replace the self pointer + return _ffi_api.DatabaseQuerySchedule( # type: ignore # pylint: disable=no-member + self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member + ) + + def query_ir_module( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[IRModule]: + """Query an IRModule from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + workload_name : Optional[str] + The workload name to be searched for. + + Returns + ------- + mod : Optional[IRModule] + The IRModule corresponding to the given workload. + """ + # Using self._outer to replace the self pointer + return _ffi_api.DatabaseQueryIRModule( # type: ignore # pylint: disable=no-member + self._outer(), mod, target, workload_name # type: ignore # pylint: disable=no-member + ) + def __len__(self) -> int: """Get the number of records in the database. diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index d082ff7a3901..0976e158aaf0 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -217,6 +217,9 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, + PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, + PyDatabaseNode::FQuerySchedule f_query_schedule, + PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size) { ObjectPtr n = make_object(); n->f_has_workload = f_has_workload; @@ -224,6 +227,9 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, n->f_commit_tuning_record = f_commit_tuning_record; n->f_get_top_k = f_get_top_k; n->f_get_all_tuning_records = f_get_all_tuning_records; + n->f_query_tuning_record = f_query_tuning_record; + n->f_query_schedule = f_query_schedule; + n->f_query_ir_module = f_query_ir_module; n->f_size = f_size; return Database(n); } diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index e6342f1c3536..777c5589a141 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -18,11 +18,13 @@ """Test Meta Schedule Database""" import os.path as osp import tempfile -from typing import Callable +from typing import Callable, Optional, List import tvm import tvm.testing +from tvm.target import Target from tvm import meta_schedule as ms +from tvm.meta_schedule.database import TuningRecord, Workload from tvm import tir from tvm.ir.module import IRModule from tvm.script import tir as T @@ -106,6 +108,123 @@ def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): assert str(arg0.as_json()) == str(arg1.as_json()) +@ms.utils.derived_object +class PyMemoryDatabaseDefault(ms.database.PyDatabase): + def __init__(self): + super().__init__() + self.tuning_records_: List[TuningRecord] = [] + self.workloads_: List[Workload] = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workloads_: + if tvm.ir.structural_equal(mod, workload.mod): + return True + + def commit_workload(self, mod: IRModule) -> ms.database.Workload: + if self.has_workload(mod): + for workload in self.workloads_: + if tvm.ir.structural_equal(mod, workload.mod): + return workload + else: + workload = ms.database.Workload(mod) + self.workloads_.append(workload) + return workload + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.tuning_records_.append(record) + + def get_all_tuning_records(self) -> List[TuningRecord]: + return self.tuning_records_ + + def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]: + return sorted( + list( + filter( + lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod), + self.tuning_records_, + ) + ), + key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9, + )[:top_k] + + def __len__(self) -> int: + return len(self.tuning_records_) + + +@ms.utils.derived_object +class PyMemoryDatabaseOverride(ms.database.PyDatabase): + def __init__(self): + super().__init__() + self.tuning_records_: List[TuningRecord] = [] + self.workloads_: List[Workload] = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workloads_: + if tvm.ir.structural_equal(mod, workload.mod): + return True + + def commit_workload(self, mod: IRModule) -> ms.database.Workload: + if self.has_workload(mod): + for workload in self.workloads_: + if tvm.ir.structural_equal(mod, workload.mod): + return workload + else: + workload = ms.database.Workload(mod) + self.workloads_.append(workload) + return workload + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.tuning_records_.append(record) + + def get_all_tuning_records(self) -> List[TuningRecord]: + return self.tuning_records_ + + def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]: + return sorted( + list( + filter( + lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod), + self.tuning_records_, + ) + ), + key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9, + )[:top_k] + + def __len__(self) -> int: + return len(self.tuning_records_) + + def query_tuning_record( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[TuningRecord]: + if self.has_workload(mod): + records = self.get_top_k(self.commit_workload(mod), 2) + if len(records) == 1: + return records[0] + elif len(records) == 2: + return records[1] # return the 2nd best if there are two records + return None + + def query_schedule( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[Schedule]: + record = self.query_tuning_record(mod, target, workload_name) + if record is not None: + sch = Schedule(record.workload.mod) + record.trace.apply_to_schedule(sch, remove_postproc=False) + return sch + return None + + def query_ir_module( + self, mod: IRModule, target: Target, workload_name: Optional[str] = None + ) -> Optional[IRModule]: + record = self.query_tuning_record(mod, target, workload_name) + if record is not None: + sch = Schedule(record.workload.mod) + record.trace.apply_to_schedule(sch, remove_postproc=False) + return sch.mod + return None + + def test_meta_schedule_tuning_record_round_trip(): mod: IRModule = Matmul with tempfile.TemporaryDirectory() as tmpdir: @@ -302,10 +421,10 @@ def test_meta_schedule_database_union(): db_2 = ms.database.MemoryDatabase() trace = _create_schedule(mod, _schedule_matmul).trace - def query(db): + def query(db): # pylint: disable=invalid-name return db.query_tuning_record(mod=mod, target=target, workload_name="main").run_secs - def commit_record(db, run_sec): + def commit_record(db, run_sec): # pylint: disable=invalid-name db.commit_tuning_record( ms.database.TuningRecord( trace, @@ -331,5 +450,91 @@ def commit_record(db, run_sec): assert run_secs.value == 1.0 +def test_meta_schedule_pydatabase_default_query(): + + mod: IRModule = Matmul + target = tvm.target.Target("llvm") + arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) + db = PyMemoryDatabaseDefault() # pylint: disable=invalid-name + sch = _create_schedule(mod, _schedule_matmul) + trace = sch.trace + + def query(db, mod, target, kind): # pylint: disable=invalid-name + return db.query(mod=mod, target=target, workload_name="main", kind=kind) + + def commit_record(trace, db, run_sec): # pylint: disable=invalid-name + db.commit_tuning_record( + ms.database.TuningRecord( + trace, + workload=db.commit_workload(mod), + run_secs=[run_sec], + target=target, + args_info=arg_info, + ) + ) + + commit_record(trace, db, 1.0) + record = query(db, mod, target, "record") + assert record is not None and record.run_secs[0].value == 1.0 + sch_res = query(db, mod, target, "schedule") + assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) + mod_res = query(db, mod, target, "ir_module") + assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) + + commit_record(Schedule(mod).trace, db, 0.2) # Empty Trace + record = query(db, mod, target, "record") + assert record is not None and record.run_secs[0].value == 0.2 + sch_res = query(db, mod, target, "schedule") + assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, mod) + mod_res = query(db, mod, target, "ir_module") + assert mod_res is not None and tvm.ir.structural_equal(mod_res, mod) + + +def test_meta_schedule_pydatabase_override_query(): + + mod: IRModule = Matmul + target = tvm.target.Target("llvm") + arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) + db = PyMemoryDatabaseOverride() # pylint: disable=invalid-name + sch = _create_schedule(mod, _schedule_matmul) + trace = sch.trace + + def query(db, mod, target, kind): # pylint: disable=invalid-name + return db.query(mod=mod, target=target, workload_name="main", kind=kind) + + def commit_record(trace, db, run_sec): # pylint: disable=invalid-name + db.commit_tuning_record( + ms.database.TuningRecord( + trace, + workload=db.commit_workload(mod), + run_secs=[run_sec], + target=target, + args_info=arg_info, + ) + ) + + commit_record(trace, db, 1.14) + record = query(db, mod, target, "record") + assert record is not None and record.run_secs[0].value == 1.14 + sch_res = query(db, mod, target, "schedule") + assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) + mod_res = query(db, mod, target, "ir_module") + assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) + + commit_record(Schedule(mod).trace, db, 0.514) # Empty Trace + record = query(db, mod, target, "record") + assert record is not None and record.run_secs[0].value == 1.14 # Override to 2nd best + sch_res = query(db, mod, target, "schedule") + assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) + mod_res = query(db, mod, target, "ir_module") + assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) + + +def test_meta_schedule_pydatabase_current(): + db = PyMemoryDatabaseDefault() # pylint: disable=invalid-name + with db: # pylint: disable=not-context-manager + assert ms.database.Database.current() == db + + if __name__ == "__main__": tvm.testing.main()