From 734a22df45fd169beec633553448d704ad549d11 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 25 Jan 2022 23:45:24 -0800 Subject: [PATCH] [Meta Schedule] Add `ApplyHisotryBest` Meta Schedule Context (#10049) * Add ApplyHisotryBest. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng * Retrigger CI. * Update integration.py Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- python/tvm/meta_schedule/integration.py | 9 ++- src/meta_schedule/integration.cc | 22 ++++++- src/meta_schedule/utils.h | 16 +++++ .../test_meta_schedule_integration.py | 58 +++++++++++++++++++ 4 files changed, 103 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 47003c6faa25..794591cefed3 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -25,6 +25,7 @@ from tvm.target import Target from tvm.tir import PrimFunc +from .database import Database from . import _ffi_api @@ -174,7 +175,13 @@ def __init__(self) -> None: @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): - pass + """An integration context that allows application of historically best record from database""" + + database: Database + """ The database to be queried from""" + + def __init__(self, database) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member def extract_task( diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index cf4262814947..e9d3012f789d 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace meta_schedule { @@ -112,7 +114,21 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Optional> dispatched) { - throw; + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + ICHECK(HasOnlyOneFunction(mod)) << mod; + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + // Unify func name to make sure it can be found in database + prim_mod = UnifyFuncName(prim_mod); + if (database->HasWorkload(prim_mod)) { + Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); + if (records.size() == 1) { + LOG(INFO) << "Applied history best for " << task_name << "."; + return records[0]->workload->mod; + } + } + return NullOpt; } /**************** FFI ****************/ @@ -146,6 +162,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { return TaskExtraction(); }); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") + .set_body_typed([](Database database) -> ApplyHistoryBest { + return ApplyHistoryBest(database); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index bd76ca794a9a..afeb159052ee 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -351,6 +351,22 @@ inline int GetTargetNumCores(const Target& target) { return num_cores; } +/*! + * \brief Unify the function name in workload to "main". + * \param mod The workload. + * \return The new workload with unified function name. + * \note If the name is not unified, the workload may not be found in database. + */ +inline IRModule UnifyFuncName(const IRModule& mod) { + if (!mod->ContainGlobalVar("main") && mod->GetGlobalTypeVars().size() == 1) { + IRModule new_mod = IRModule( + Map({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}})); + return new_mod; + } else { + return mod; + } +} + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index f508c7d252e1..bc1d5f268ba0 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -22,10 +22,14 @@ import tvm from tvm import meta_schedule as ms from tvm.ir.module import IRModule +from tvm.tir import Schedule +from tvm.target import Target +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord from tvm.meta_schedule.integration import ( ExtractedTask, MetaScheduleContext, TaskExtraction, + ApplyHistoryBest, ) from tvm.meta_schedule.testing import get_network from tvm.script import tir as T @@ -116,5 +120,59 @@ def test_meta_schedule_integration_extract_from_resnet(): assert len(extracted_tasks) == 30 +def test_meta_schedule_integration_apply_history_best(): + class DummyDatabase(PyDatabase): + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + database = DummyDatabase() + env = ApplyHistoryBest(database) + workload = database.commit_workload(MockModule) + database.commit_tuning_record( + TuningRecord(Schedule(MockModule).trace, [1.0], workload, Target("llvm"), []) + ) + mod = env.query(task_name="mock-task", mod=mod, dispatched=[MockModule]) + assert tvm.ir.structural_equal(mod, workload.mod) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))