diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 88db2e2277867..fa488a38ce0a2 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -357,6 +357,22 @@ class Database : public runtime::ObjectRef { */ TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing); + /*! + * \brief A database composed of multiple databases, allowing users to guide IR rewriting using + * combined knowledge of those databases. To each query, it returns the best record among all the + * databases given. + * \param databases The list of databases to be combined. + * \return The combined database. + */ + TVM_DLL static Database UnionDatabase(Array databases); + /*! + * \brief A database composed of multiple databases, allowing users to guide IR rewriting using + * combined knowledge of those databases. To each query, it returns the record from the first + * database that responds to the query. + * \param databases The list of databases to be combined, in the order they are queried. + * \return The combined database. + */ + TVM_DLL static Database OrderedUnionDatabase(Array databases); /*! * \brief Create a database with customized methods on the python-side. * \param f_has_workload The packed function of `HasWorkload`. 
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py index 7726daf6eb633..679923e47936e 100644 --- a/python/tvm/meta_schedule/database/__init__.py +++ b/python/tvm/meta_schedule/database/__init__.py @@ -21,4 +21,6 @@ from .database import Database, PyDatabase, TuningRecord, Workload from .json_database import JSONDatabase from .memory_database import MemoryDatabase +from .ordered_union_database import OrderedUnionDatabase from .schedule_fn_database import ScheduleFnDatabase +from .union_database import UnionDatabase diff --git a/python/tvm/meta_schedule/database/ordered_union_database.py b/python/tvm/meta_schedule/database/ordered_union_database.py new file mode 100644 index 0000000000000..35b0a9e282c10 --- /dev/null +++ b/python/tvm/meta_schedule/database/ordered_union_database.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A database consists of multiple databases.""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .database import Database + + +@register_object("meta_schedule.OrderedUnionDatabase") +class OrderedUnionDatabase(Database): + """A database composed of multiple databases, allowing users to guide IR rewriting using + combined knowledge of those databases. To each query, it returns the record from the first + database that responds to the query. + + Examples + -------- + Examples below demonstrate the use cases of and difference between UnionDatabase and + OrderedUnionDatabase. + + Assumption: + * db1, db2 do not have tuning records for the target workload. + * Each of db3, db4, db5 has tuning records r3, r4, r5 for target workload respectively. + + .. code-block:: python + + ### Case 1. `UnionDatabase` + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + db4 # has r4 + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + + ### Case 2. `OrderedUnionDatabase` + merged_db = ms.database.OrderedUnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + db4 # has r4 + ) + # returns r3 + merged_db.query_tuning_record(..., target_workload) + + ### Case 3. Mix-use scenario + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + ms.database.OrderedUnionDatabase( # returns r4 + db4, # has r4 + db5, # has r5 + ) + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + + ### Case 4. Another mix-use scenario + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + ms.database.UnionDatabase( # returns best one between r4 and r5 + db4, # has r4 + db5, # has r5 + ) + ) + # returns the best one among r3, r4 and r5 + merged_db.query_tuning_record(..., target_workload) + + ### Case 5. 
Yet another mix-use scenario + merged_db = ms.database.OrderedUnionDatabase( + db1, # no record + db2, # no record + ms.database.UnionDatabase( # returns best one between r3 and r4 + db3, # has r3 + db4, # has r4 + ), + db5, # has r5 + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + """ + + def __init__(self, *databases: Database) -> None: + """Construct a merged database from multiple databases. + + Parameters + ---------- + *databases : Database + The databases to combine, queried in the order given. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseOrderedUnionDatabase, # type: ignore # pylint: disable=no-member + databases, + ) diff --git a/python/tvm/meta_schedule/database/union_database.py b/python/tvm/meta_schedule/database/union_database.py new file mode 100644 index 0000000000000..ae55ebe796145 --- /dev/null +++ b/python/tvm/meta_schedule/database/union_database.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A database consists of multiple databases.""" +from tvm._ffi import register_object + +from .. 
import _ffi_api +from .database import Database + + +@register_object("meta_schedule.UnionDatabase") +class UnionDatabase(Database): + """A database composed of multiple databases, allowing users to guide IR rewriting using + combined knowledge of those databases. To each query, it returns the best record among all the + databases given. + + Examples + -------- + Examples below demonstrate the use cases of and difference between UnionDatabase and + OrderedUnionDatabase. + + Assumption: + * db1, db2 do not have tuning records for the target workload. + * Each of db3, db4, db5 has tuning records r3, r4, r5 for target workload respectively. + + .. code-block:: python + + ### Case 1. `UnionDatabase` + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + db4 # has r4 + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + + ### Case 2. `OrderedUnionDatabase` + merged_db = ms.database.OrderedUnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + db4 # has r4 + ) + # returns r3 + merged_db.query_tuning_record(..., target_workload) + + ### Case 3. Mix-use scenario + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + ms.database.OrderedUnionDatabase( # returns r4 + db4, # has r4 + db5, # has r5 + ) + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + + ### Case 4. Another mix-use scenario + merged_db = ms.database.UnionDatabase( + db1, # no record + db2, # no record + db3, # has r3 + ms.database.UnionDatabase( # returns best one between r4 and r5 + db4, # has r4 + db5, # has r5 + ) + ) + # returns the best one among r3, r4 and r5 + merged_db.query_tuning_record(..., target_workload) + + ### Case 5. 
Yet another mix-use scenario + merged_db = ms.database.OrderedUnionDatabase( + db1, # no record + db2, # no record + ms.database.UnionDatabase( # returns best one between r3 and r4 + db3, # has r3 + db4, # has r4 + ), + db5, # has r5 + ) + # returns the better one between r3 and r4 + merged_db.query_tuning_record(..., target_workload) + """ + + def __init__(self, *databases: Database) -> None: + """Construct a merged database from multiple databases. + + Parameters + ---------- + *databases : Database + The list of databases to combine. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseUnionDatabase, # type: ignore # pylint: disable=no-member + databases, + ) diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 2e4f852608353..91b96c82479f9 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -25,28 +25,6 @@ namespace tvm { namespace meta_schedule { -/*! \brief The struct defining comparison function of sorting by mean run seconds. */ -struct SortTuningRecordByMeanRunSecs { - static const constexpr double kMaxMeanTime = 1e10; - - static double Mean(const Array& a) { - if (a.empty()) { - return kMaxMeanTime; - } - double sum = 0.0; - for (const FloatImm& i : a) { - sum += i->value; - } - return sum / a.size(); - } - - bool operator()(const TuningRecord& a, const TuningRecord& b) const { - double a_time = Mean(a->run_secs.value_or({})); - double b_time = Mean(b->run_secs.value_or({})); - return a_time < b_time; - } -}; - /*! * \brief Read lines from a json file. * \param path The path to the json file. 
diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc new file mode 100644 index 0000000000000..3aaee2112c0c4 --- /dev/null +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class OrderedUnionDatabaseNode : public DatabaseNode { + public: + Array<Database> databases; /* queried front to back by QueryTuningRecord */ + + void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); } + + static constexpr const char* _type_key = "meta_schedule.OrderedUnionDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(OrderedUnionDatabaseNode, DatabaseNode); + + public: + Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target, + const String& task_name) final { + for (const Database& db : databases) { + if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) { + return record; /* first database that has a record wins */ + } + } + return NullOpt; + } + + bool HasWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.HasWorkload"; + throw; + } + + Workload CommitWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitWorkload"; + throw; + } + + void CommitTuningRecord(const TuningRecord& record) final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitTuningRecord"; + throw; + } + + Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetTopK"; + throw; + } + + Array<TuningRecord> GetAllTuningRecords() final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetAllTuningRecords"; + throw; + } + + int64_t Size() final { + LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.size"; + throw; + } +}; + +Database Database::OrderedUnionDatabase(Array<Database> databases) { + ObjectPtr<OrderedUnionDatabaseNode> n = make_object<OrderedUnionDatabaseNode>(); + n->databases = std::move(databases); + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") + .set_body_typed(Database::OrderedUnionDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc new file mode 
100644 index 0000000000000..6d19a38c6d9e3 --- /dev/null +++ b/src/meta_schedule/database/union_database.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class UnionDatabaseNode : public DatabaseNode { + public: + Array<Database> databases; /* every one of these is consulted on each query */ + + void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); } + + static constexpr const char* _type_key = "meta_schedule.UnionDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnionDatabaseNode, DatabaseNode); + + public: + Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target, + const String& task_name) final { + std::vector<TuningRecord> results; + results.reserve(databases.size()); + for (const Database& db : databases) { + if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) { + results.push_back(record.value()); + } + } + /* stable sort: records with equal mean time keep the order of `databases` */ std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); + return results.empty() ? 
Optional<TuningRecord>(NullOpt) : results[0]; + } + + bool HasWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.HasWorkload"; + throw; + } + + Workload CommitWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitWorkload"; + throw; + } + + void CommitTuningRecord(const TuningRecord& record) final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitTuningRecord"; + throw; + } + + Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.GetTopK"; + throw; + } + + Array<TuningRecord> GetAllTuningRecords() final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.GetAllTuningRecords"; + throw; + } + + int64_t Size() final { + LOG(FATAL) << "NotImplementedError: UnionDatabase.size"; + throw; + } +}; + +Database Database::UnionDatabase(Array<Database> databases) { + ObjectPtr<UnionDatabaseNode> n = make_object<UnionDatabaseNode>(); + n->databases = std::move(databases); + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase").set_body_typed(Database::UnionDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index db37935ec2063..ad56fa7f6a526 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -404,6 +404,28 @@ inline Array AsIntArray(const ObjectRef& obj) { return results; } +/*! \brief The struct defining comparison function of sorting by mean run seconds. 
*/ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; /* mean reported for an empty run_secs list */ + + static double Mean(const Array<FloatImm>& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index b14c18e55f4b8..e5b8cd77445fa 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -412,17 +412,12 @@ def schedule_fn(sch): return True return False - link_params = True - with StringIO() as stderr_buf, redirect_stderr(stderr_buf): with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext( opt_level=3, - config={ - "relay.backend.use_meta_schedule": True, - "relay.FuseOps.link_params": link_params, - }, + config={"relay.backend.use_meta_schedule": True}, ): - executor = Executor("graph", {"link-params": link_params}) + executor = Executor("graph", {"link-params": True}) lib = relay.build(relay_mod, target=target, executor=executor) # Workload look up should succeed. This does not work when the test is invoked from pytest. 
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index ff0f350d89147..e6342f1c35364 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -294,5 +294,42 @@ def test_meta_schedule_database_reload(): _equal_record(ret[1], records[2]) +def test_meta_schedule_database_union(): + mod: IRModule = Matmul + target = tvm.target.Target("llvm") + arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) + db_1 = ms.database.MemoryDatabase() + db_2 = ms.database.MemoryDatabase() + trace = _create_schedule(mod, _schedule_matmul).trace + + def query(db): + return db.query_tuning_record(mod=mod, target=target, workload_name="main").run_secs + + def commit_record(db, run_sec): + db.commit_tuning_record( + ms.database.TuningRecord( + trace, + workload=db.commit_workload(mod), + run_secs=[run_sec], + target=target, + args_info=arg_info, + ) + ) + + commit_record(db_1, 1.0) + (run_sec,) = query(db_1) + assert run_sec.value == 1.0 + + commit_record(db_2, 0.5) + (run_sec,) = query(db_2) + assert run_sec.value == 0.5 + + (run_secs,) = query(ms.database.UnionDatabase(db_1, db_2)) + assert run_secs.value == 0.5 + + (run_secs,) = query(ms.database.OrderedUnionDatabase(db_1, db_2)) + assert run_secs.value == 1.0 + + if __name__ == "__main__": tvm.testing.main()