From 837e2f78b5e4a1eb67bbb34325d911d6fb466091 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 26 Aug 2022 20:20:42 -0700 Subject: [PATCH] [MetaSchedule][UX] Make `Database` with-able (#12520) `ApplyHistoryBest` right now plays a role as the database adaptor to query inside the database. In fact, the logic could be simplified and users only have to deal with `Database` instead of this extra object. - [x] Add `EnterWithScope`/`ExitWithScope`/`Current` to Database - [x] Migrate `te_filter_func` => "tir_filter" in Relay's pass context - [x] Migrate `f_take_tuning_record` => "Database.query_tuning_record" - [x] Migrate `TECompiler` to use `Database` - [x] Remove apply-history-best Next PR: - Migrate `f_direct_dispatch` (potentially unify with `apply_fixed_schedule`?) --- .../tvm/meta_schedule/apply_history_best.h | 115 ------------ include/tvm/meta_schedule/database.h | 28 +++ include/tvm/meta_schedule/extracted_task.h | 20 --- .../tvm/auto_scheduler/testing/tune_relay.py | 93 +++++----- python/tvm/meta_schedule/__init__.py | 1 - .../tvm/meta_schedule/apply_history_best.py | 130 -------------- python/tvm/meta_schedule/database/database.py | 104 ++++++++++- python/tvm/meta_schedule/default_config.py | 4 - python/tvm/meta_schedule/relay_integration.py | 29 ++- .../tvm/meta_schedule/testing/tune_relay.py | 30 +++- python/tvm/meta_schedule/testing/utils.py | 26 +-- python/tvm/meta_schedule/tune.py | 12 +- src/meta_schedule/apply_history_best.cc | 165 ------------------ src/meta_schedule/database/database.cc | 64 +++++++ src/meta_schedule/extracted_task.cc | 70 -------- src/meta_schedule/utils.h | 1 - src/relay/backend/task_extraction.cc | 25 +-- src/relay/backend/te_compiler.cc | 1 + src/relay/backend/te_compiler_cache.cc | 70 ++++---- src/relay/backend/utils.cc | 73 ++++++++ src/relay/backend/utils.h | 31 ++++ .../test_meta_schedule_auto_tensorize.py | 25 ++- tests/python/unittest/test_link_params.py | 19 +- .../test_meta_schedule_integration.py | 62 +------ .../test_meta_schedule_multi_anchor.py | 2 +- .../test_meta_schedule_relay_tir_compute.py | 18 +- .../unittest/test_meta_schedule_tune_relay.py | 57 +++--- 27 files changed, 511 insertions(+), 764 deletions(-) delete mode 100644 include/tvm/meta_schedule/apply_history_best.h delete mode 100644 python/tvm/meta_schedule/apply_history_best.py delete mode 100644 src/meta_schedule/apply_history_best.cc diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h deleted file mode 100644 index 44a34b3ee496..000000000000 --- a/include/tvm/meta_schedule/apply_history_best.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ -#define TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace meta_schedule { - -/*! - * \brief An integration context that allows application of historically best records from a - * database - */ -class ApplyHistoryBestNode : public runtime::Object { - public: - /*! \brief A callback function that filters TE compute */ - using FTEFilterFunc = runtime::TypedPackedFunc( - const Array&, const Array&)>; - /*! \brief A callback function that takes a tuning record and does something with it */ - using FTakeTuningRecord = runtime::TypedPackedFunc; - using FDirectDispatch = runtime::TypedPackedFunc(const IRModule&)>; - - /*! \brief The database to be queried from */ - Database database{nullptr}; - /*! \brief The filtering function for TE computation */ - FTEFilterFunc te_filter_func{nullptr}; - /*! \brief The logging function to be used */ - PackedFunc logging_func; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("database", &database); - // `te_filter_func` is not visited - // `logging_func` is not visited - } - /*! - * \brief Query the best entry from the database - * \param task_name The name of the task to be queried - * \param mod The module to be queried - * \param target The target to be queried - * \param dispatched The IRs after dispatch - * \param f_take_tuning_record A callback function that takes a tuning record and does something - * with it. - * \param f_direct_dispatch A function that directly dispatches an IRModule to the given workload - * as result if available, skipping the database query. - */ - Optional Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched, - FTakeTuningRecord f_take_tuning_record, - FDirectDispatch f_direct_dispatch = nullptr); - - static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; - TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object); -}; - -/*! - * \brief Managed reference to ApplyHistoryBestNode - * \sa ApplyHistoryBestNode - */ -class ApplyHistoryBest : public runtime::ObjectRef { - public: - /*! - * \brief Constructor - * \param database The database to be queried from - * \param te_filter_func The filtering function for TE computation - * \param logging_func The logging function to use - */ - explicit ApplyHistoryBest(Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func, - PackedFunc logging_func); - /*! - * \brief The current ApplyHistoryBest in the context - * \return The ApplyHistoryBest in the current scope. - */ - static Optional Current(); - - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, runtime::ObjectRef, - ApplyHistoryBestNode); - - protected: - friend class ApplyHistoryBestInternal; - /*! \brief Entering the scope of the context manager */ - void EnterWithScope(); - /*! \brief Exiting the scope of the context manager */ - void ExitWithScope(); -}; - -} // namespace meta_schedule -} // namespace tvm - -#endif // TVM_META_SCHEDULE_APPLY_HISTORY_BEST_H_ diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 1c260d9d748a..0e7f45d39332 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -203,6 +203,27 @@ class DatabaseNode : public runtime::Object { * \return The size of the database. */ virtual int64_t Size() = 0; + /*! + * \brief Query the best record of the given workload from the database. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \return The best record of the given workload; NullOpt if not found. + */ + virtual Optional QueryTuningRecord(IRModule mod, Target target); + /*! + * \brief Query the best schedule of the given workload from the database. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \return The schedule in the best schedule of the given workload; NullOpt if not found. + */ + virtual Optional QuerySchedule(IRModule mod, Target target); + /*! + * \brief Query the best IRModule of the given workload from the database. + * \param mod The IRModule to be searched for. + * \param target The target to be searched for. + * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. + */ + virtual Optional QueryIRModule(IRModule mod, Target target); static constexpr const char* _type_key = "meta_schedule.Database"; TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); @@ -339,6 +360,13 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FSize f_size); + /*! \return The current Database in the scope. */ + static Optional Current(); + /*! \brief Entering the scope of the context manager */ + void EnterWithScope(); + /*! \brief Exiting the scope of the context manager */ + void ExitWithScope(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); }; diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index bce40e6b95f0..239bf0dc5777 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -76,26 +76,6 @@ class ExtractedTask : public runtime::ObjectRef { ExtractedTaskNode); }; -/*! - * \brief The default TE task filter - * \param args The input/output arguments of the TE compute graph - * \param constants Raw data for constant tensors in args. If the size of this array is N, the last - * N tensors in args will be treated as constant tensors. - * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc - */ -Optional DefaultTaskFilter(const Array& args, - const Array& constants); - -/*! - * \brief The default TE task filter, with `te.extern` allowed - * \param args The input/output arguments of the TE compute graph - * \param constants Raw data for constant tensors in args. If the size of this array is N, the last - * N tensors in args will be treated as constant tensors. - * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc - */ -Optional DefaultTaskFilterAllowExtern(const Array& args, - const Array& constants); - } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py b/python/tvm/auto_scheduler/testing/tune_relay.py index fe747af7972c..2d84389f9de1 100644 --- a/python/tvm/auto_scheduler/testing/tune_relay.py +++ b/python/tvm/auto_scheduler/testing/tune_relay.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import json import os +from distutils.util import strtobool import tvm from tvm import auto_scheduler @@ -26,7 +26,7 @@ from tvm import relay from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer +from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data from tvm.meta_schedule.utils import cpu_count from tvm.support import describe @@ -170,53 +170,62 @@ def main(): ARGS.input_shape, cache_dir=ARGS.cache_dir, ) - input_info = {input_name: input_shape} + input_info = [ + { + "name": input_name, + "shape": input_shape, + "dtype": input_dtype, + }, + ] input_data = { - item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape + item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in input_info } - for input_name, input_shape in input_info.items(): - print(f" input_name : {input_name}") - print(f" input_shape: {input_shape}") - print(f" input_dtype: {input_dtype}") + for item in input_info: + print(f" input_name : {item['name']}") + print(f" input_shape: {item['shape']}") + print(f" input_dtype: {item['dtype']}") with ms.Profiler() as profiler: - tasks, task_weights = auto_scheduler.extract_tasks( - mod["main"], - params, - target=ARGS.target, - hardware_params=hardware_params, - ) - for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)): - print( - f"==== Task {idx}: {task.desc} " - f"(weight {task_weight} key: {task.workload_key}) =====" - ) - print(task.compute_dag) - - if ARGS.num_trials > 0: - tuner = auto_scheduler.TaskScheduler(tasks, task_weights) - tuner.tune( - auto_scheduler.TuningOptions( - num_measure_trials=ARGS.num_trials, - runner=runner, - measure_callbacks=[ - auto_scheduler.RecordToFile(log_file), - ], - ), - adaptive_training=ARGS.adaptive_training, + with ms.Profiler.timeit("TaskExtraction"): + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], + params, + target=ARGS.target, + hardware_params=hardware_params, ) + for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)): + print( + f"==== Task {idx}: {task.desc} " + f"(weight {task_weight} key: {task.workload_key}) =====" + ) + print(task.compute_dag) + + with ms.Profiler.timeit("Tuning"): + if ARGS.num_trials > 0: + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tuner.tune( + auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + runner=runner, + measure_callbacks=[ + auto_scheduler.RecordToFile(log_file), + ], + ), + adaptive_training=ARGS.adaptive_training, + ) relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend] - with auto_scheduler.ApplyHistoryBest(log_file): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_auto_scheduler": True}, - ): - lib = relay_build( - mod, - target=ARGS.target, - params=params, - ) + with ms.Profiler.timeit("PostTuningCompilation"): + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_auto_scheduler": True}, + ): + lib = relay_build( + mod, + target=ARGS.target, + params=params, + ) print("Tuning Time:") print(profiler.table()) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index f60d0a5490f5..cf348d49f4e2 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -30,7 +30,6 @@ search_strategy, space_generator, ) -from .apply_history_best import ApplyHistoryBest from .extracted_task import ExtractedTask from .profiler import Profiler from .relay_integration import ( diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py deleted file mode 100644 index a7b9b20bf244..000000000000 --- a/python/tvm/meta_schedule/apply_history_best.py +++ /dev/null @@ -1,130 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""A context manager that injects the best tuning record in the database into compilation""" -import logging -from typing import Callable, List, Optional, Union - -from tvm._ffi import get_global_func, register_object -from tvm.ir import IRModule -from tvm.runtime import Object -from tvm.target import Target -from tvm.te import Tensor -from tvm.tir import PrimFunc - -from . import _ffi_api -from .database import Database, TuningRecord -from .utils import make_logging_func - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -@register_object("meta_schedule.ApplyHistoryBest") -class ApplyHistoryBest(Object): - """An integration context that allows application of historically best records from a database - - Parameters - ---------- - database : Database - The database to be queried from - te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None - The filtering function for TE computation - If it's a string, it's the name of the filtering function. Built in functions are - - "meta_schedule.DefaultTaskFilter" - - "meta_schedule.DefaultTaskFilterAllowExtern" - If it's None, it's the default filtering function - If it's a callable, it's the filtering function - """ - - database: Database - - def __init__( - self, - database: Database, - te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None, - ) -> None: - if isinstance(te_filter_func, str): - te_filter_func = get_global_func(te_filter_func) - self.__init_handle_by_constructor__( - _ffi_api.ApplyHistoryBest, # type: ignore # pylint: disable=no-member - database, - te_filter_func, - make_logging_func(logger), - ) - - def query( - self, - task_name: str, - mod: IRModule, - target: Target, - dispatched: Optional[List[IRModule]], - f_take_tuning_record: Optional[Callable[[TuningRecord], None]] = None, - f_direct_dispatch: Optional[Callable[[IRModule], Optional[IRModule]]] = None, - ) -> Union[IRModule, None]: - """The entry point of the integration - - Parameters - ---------- - task_name : str - The name of the task extracted - mod : IRModule - The high-level IR - target: Target - Target Info - dispatched : Optional[List[IRModule]] - A list of low-level IRs that the high-level IR could potentially dispatch to - f_take_tuning_record : Optional[Callable[[TuningRecord], None]] = None - A callback function that takes a tuning record and does something with it - f_direct_dispatch : Optional[Callable[[IRModule], Optional[IRModule]]] = None - A function that directly dispatches an IRModule to the given workload as result if - available, skipping the database query. - - Returns - ------- - result : IRModule or None - Currently we only have to return tir::PrimFunc, but we wrap it under IRModule for - more general future use. None is returned if there is no feedback hint. - """ - return _ffi_api.ApplyHistoryBestQuery( # type: ignore # pylint: disable=no-member - self, - task_name, - mod, - target, - dispatched, - f_take_tuning_record, - f_direct_dispatch, - ) - - @staticmethod - def current() -> Optional["ApplyHistoryBest"]: - """The context manager in the current scope - - Returns - ------- - ctx : Optional[ApplyHistoryBest] - The ApplyHistoryBest context manager in the current scope. - None if it's currently not under any ApplyHistoryBest context. - """ - return _ffi_api.ApplyHistoryBestCurrent() # type: ignore # pylint: disable=no-member - - def __enter__(self) -> "ApplyHistoryBest": - """Entering the scope of the context manager""" - _ffi_api.ApplyHistoryBestEnterScope(self) # type: ignore # pylint: disable=no-member - return self - - def __exit__(self, ptype, value, trace) -> None: - """Exiting the scope of the context manager""" - _ffi_api.ApplyHistoryBestExitScope(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 0c11f77591cc..68283b4554e5 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. """TuningRecord database""" -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union from tvm._ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.target import Target -from tvm.tir.schedule import Trace +from tvm.tir.schedule import Schedule, Trace +from typing_extensions import Literal # pylint: disable=wrong-import-order from .. import _ffi_api from ..arg_info import ArgInfo @@ -234,6 +235,105 @@ def __len__(self) -> int: """ return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member + def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]: + """Query the best record of the given workload from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + + Returns + ------- + tuning_record : Optional[TuningRecord] + The best record of the given workload; None if not found. + """ + return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type: ignore # pylint: disable=no-member + + def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]: + """Query the best schedule of the given workload from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + + Returns + ------- + schedule : Optional[Schedule] + The best schedule of the given workload; None if not found. + """ + return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type: ignore # pylint: disable=no-member + + def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]: + """Query the best IRModule of the given workload from the database. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + + Returns + ------- + ir_module : Optional[IRModule] + The best IRModule of the given workload; None if not found. + """ + return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type: ignore # pylint: disable=no-member + + def query( + self, + mod: IRModule, + target: Target, + kind: Union[ + Literal["schedule"], + Literal["record"], + Literal["ir_module"], + ] = "schedule", + ) -> Union[Schedule, IRModule, TuningRecord]: + """Query the database to retrieve the best optimization outcome of the given workload. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + target : Target + The target to be searched for. + kind : str = "schedule" | "record" | "ir_module" + The kind of the optimization outcome to be returned. + + Returns + ------- + result : Union[Schedule, IRModule, TuningRecord] + The best optimization outcome of the given workload. + """ + if kind == "schedule": + return self.query_schedule(mod, target) + if kind == "record": + return self.query_tuning_record(mod, target) + if kind == "ir_module": + return self.query_ir_module(mod, target) + raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"') + + def __enter__(self) -> "Database": + """Entering the scope of the context manager""" + _ffi_api.DatabaseEnterWithScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.DatabaseExitWithScope(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def current() -> Optional["Database"]: + """Get the current database under scope.""" + return _ffi_api.DatabaseCurrent() # type: ignore # pylint: disable=no-member + @register_object("meta_schedule.PyDatabase") class _PyDatabase(Database): diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py index 97cbfc58a6c1..652f09261b2f 100644 --- a/python/tvm/meta_schedule/default_config.py +++ b/python/tvm/meta_schedule/default_config.py @@ -20,7 +20,6 @@ from os import path as osp from typing import Callable, Dict, List, Optional, Union -from tvm._ffi.registry import register_func from tvm.ir import IRModule from tvm.target import Target from tvm.tir import PrimFunc @@ -44,7 +43,6 @@ FnMutatorProb = Callable[[], Dict[Mutator, float]] -@register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefined-outer-name """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): @@ -53,8 +51,6 @@ def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefine mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") - # in order to make sure the mod can be found in ApplyHistoryBest - # different func name can cause structural unequal func_names = mod.get_global_vars() (func_name,) = func_names if len(func_names) == 1 and func_name != "main": diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index d3b3ea796532..24009ab07fcf 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """MetaSchedule-Relay integration""" -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import numpy as np # type: ignore from tvm import nd @@ -23,8 +23,6 @@ from tvm.ir import IRModule, transform from tvm.runtime import NDArray from tvm.target import Target -from tvm.te import Tensor -from tvm.tir import PrimFunc from .extracted_task import ExtractedTask from .utils import autotvm_silencer @@ -38,7 +36,7 @@ def extract_task_from_relay( opt_level: int = 3, pass_config: Optional[Dict[str, Any]] = None, disabled_pass: Optional[List[str]] = None, - te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None, + tir_converter: str = "default", ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -56,13 +54,13 @@ def extract_task_from_relay( The pass config of the compiler disabled_pass : Optional[List[str]] The list of disabled passes of the compiler - te_filter_func : Callable[[List[tvm.te.Tensor], List[NDArray]], bool] - The filter function to filter out the extracted tasks - If it's a string, it's the name of the filtering function. Built in functions are - - "meta_schedule.DefaultTaskFilter" - - "meta_schedule.DefaultTaskFilterAllowExtern" - If it's None, it's the default filtering function - If it's a callable, it's the filtering function + tir_converter : str + The filter function to filter out the extracted tasks. Builtin filters: + - "default" + - "allow_extern" + The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}", + with the signature below: + (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc] Returns ------- @@ -75,8 +73,6 @@ def extract_task_from_relay( # pylint: enable=import-outside-toplevel - if isinstance(te_filter_func, str): - te_filter_func = get_global_func(te_filter_func) extract_task_func = get_global_func( "relay.backend.MetaScheduleExtractTask", allow_missing=False, @@ -89,7 +85,10 @@ def extract_task_from_relay( if disabled_pass is None: disabled_pass = [] if pass_config is None: - pass_config = {"relay.backend.use_meta_schedule": True} + pass_config = { + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": tir_converter, + } if params is None: params = {} relay_params = {} @@ -110,7 +109,7 @@ def extract_task_from_relay( else: tophub_context = autotvm.utils.EmptyContext() with tophub_context: - return list(extract_task_func(mod, target, relay_params, te_filter_func)) + return list(extract_task_func(mod, target, relay_params)) def is_meta_schedule_enabled() -> bool: diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py index 8010e36fd656..596a5a736333 100644 --- a/python/tvm/meta_schedule/testing/tune_relay.py +++ b/python/tvm/meta_schedule/testing/tune_relay.py @@ -15,16 +15,18 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import json import logging +from distutils.util import strtobool +from typing import Dict +import numpy as np # type: ignore import tvm from tvm import meta_schedule as ms from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer +from tvm.meta_schedule.testing.tune_utils import create_timer, generate_input_data from tvm.support import describe @@ -137,14 +139,24 @@ def main(): ARGS.input_shape, cache_dir=ARGS.cache_dir, ) - input_info = {input_name: input_shape} - input_data = { - item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape + input_info = [ + { + "name": input_name, + "shape": input_shape, + "dtype": input_dtype, + }, + ] + input_data: Dict[str, np.ndarray] = { + item["name"]: generate_input_data( # type: ignore + item["shape"], # type: ignore + item["dtype"], # type: ignore + ) + for item in input_info } - for input_name, input_shape in input_info.items(): - print(f" input_name : {input_name}") - print(f" input_shape: {input_shape}") - print(f" input_dtype: {input_dtype}") + for item in input_info: + print(f" input_name : {item['name']}") + print(f" input_shape: {item['shape']}") + print(f" input_dtype: {item['dtype']}") runner = ms.runner.RPCRunner( rpc_config=ARGS.rpc_config, diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index dda492008ffe..5919fb47c809 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -16,12 +16,13 @@ # under the License. """Testing utility functions in meta schedule""" from typing import Callable, Dict, Optional, Union + +from tvm import meta_schedule as ms from tvm.ir import IRModule, transform from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray from tvm.target import Target from tvm.tir import Schedule -from tvm import meta_schedule as ms def apply_fixed_schedules( @@ -29,10 +30,10 @@ def apply_fixed_schedules( target: Union[str, Target], params: Optional[Dict[str, NDArray]], schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool], - te_filter_func=None, + tir_converter: str = "default", ): """Apply fixed schedules (manually written, without any tunable knobs) as specified by - schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest. + schedule_fn to extracted tasks, and return a database that can be passed to compilation. Parameters ---------- @@ -45,13 +46,13 @@ def apply_fixed_schedules( schedule_fn : Callable[[ExtractedTask, Schedule], bool] A callable that is applied for each extracted task and the corresponding default schedule. Returns True if the given schedule should be committed to the database, False otherwise. - te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None - The filtering function for TE computation - If it's a string, it's the name of the filtering function. Built in functions are - - "meta_schedule.DefaultTaskFilter" - - "meta_schedule.DefaultTaskFilterAllowExtern" - If it's None, it's the default filtering function - If it's a callable, it's the filtering function + tir_converter : str + The filter function to filter out the extracted tasks. Builtin filters: + - "default" + - "allow_extern" + The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}", + with the signature below: + (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc] Returns ------- @@ -64,7 +65,10 @@ def apply_fixed_schedules( config[k] = v extracted_tasks = ms.extract_task_from_relay( - relay_mod, target, params, te_filter_func=te_filter_func, pass_config=config + relay_mod, + target, + params, + tir_converter=tir_converter, ) database = ms.database.MemoryDatabase() for task in extracted_tasks: diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 447fb56637ef..20eccc30a113 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -24,14 +24,12 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.runtime import Module, NDArray, vm from tvm.target import Target from tvm.te import Tensor, create_prim_func from tvm.tir import PrimFunc, Schedule from . import default_config -from .apply_history_best import ApplyHistoryBest from .builder import Builder from .cost_model import CostModel from .database import Database, TuningRecord @@ -43,7 +41,7 @@ from .runner import Runner from .schedule_rule import ScheduleRule from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace -from .space_generator import SpaceGenerator +from .space_generator import PostOrderApply, SpaceGenerator from .task_scheduler import GradientBased, RoundRobin from .tune_context import TuneContext from .utils import autotvm_silencer, batch_parameterize_config @@ -461,7 +459,7 @@ def _f_block_filter(block, target_names) -> bool: mutator_probs=mutator_probs, num_threads=num_threads, ) - with Profiler.timeit("ApplyHistoryBest"): + with Profiler.timeit("PostTuningCompilation"): bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1) if not bests: return None @@ -591,6 +589,7 @@ def tune_relay( """ # pylint: disable=import-outside-toplevel from tvm import relay + from .relay_integration import extract_task_from_relay # pylint: disable=protected-access, enable=import-outside-toplevel @@ -615,13 +614,14 @@ def tune_relay( num_threads=num_threads, ) relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend] - with Profiler.timeit("ApplyHistoryBest"): - with target, autotvm_silencer(), ApplyHistoryBest(database): + with Profiler.timeit("PostTuningCompilation"): + with target, autotvm_silencer(), database: with PassContext( opt_level=3, config={ "relay.backend.use_meta_schedule": True, "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", + "relay.backend.tir_converter": "default", }, ): return relay_build(mod, target=target, params=params) diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc deleted file mode 100644 index 62db29306777..000000000000 --- a/src/meta_schedule/apply_history_best.cc +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "./utils.h" - -namespace tvm { -namespace meta_schedule { - -/**************** Utility functions ****************/ - -template -Optional GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) { - if (mod->functions.size() != 1) { - return NullOpt; - } - for (const auto& kv : mod->functions) { - const BaseFunc& func = kv.second; - if (!func->IsInstance()) { - return NullOpt; - } else { - return on_found(kv); - } - } - return NullOpt; -} - -template -Optional GetOnlyOneFunctionKey(const IRModule& mod) { - return GetOnlyOneFunctionCommon(mod, [](auto kv) { return kv.first; }); -} - -template -Optional GetOnlyOneFunction(const IRModule& mod) { - return GetOnlyOneFunctionCommon( - mod, [](auto kv) { return Downcast(kv.second); }); -} - -template -bool HasOnlyOneFunction(const IRModule& mod) { - return GetOnlyOneFunction(mod).defined(); -} - -/**************** Context Manager ****************/ - -class ApplyHistoryBestInternal { - public: - static void EnterScope(ApplyHistoryBest ctx) { ctx.EnterWithScope(); } - static void ExitScope(ApplyHistoryBest ctx) { ctx.ExitWithScope(); } -}; - -struct ApplyHistoryBestThreadLocalEntry { - Optional ctx; -}; - -using ApplyHistoryBestThreadLocalStore = dmlc::ThreadLocalStore; - -Optional ApplyHistoryBest::Current() { - return ApplyHistoryBestThreadLocalStore::Get()->ctx; -} - -void ApplyHistoryBest::EnterWithScope() { - Optional& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx; - CHECK(!ctx.defined()) << "ValueError: Nested ApplyHistoryBest context managers are not allowed"; - ctx = *this; -} - -void ApplyHistoryBest::ExitWithScope() { - Optional& ctx = ApplyHistoryBestThreadLocalStore::Get()->ctx; - ICHECK(ctx.defined()); - ctx = NullOpt; -} - -/**************** ApplyHistoryBest ****************/ - -ApplyHistoryBest::ApplyHistoryBest(Database database, - ApplyHistoryBestNode::FTEFilterFunc te_filter_func, - PackedFunc logging_func) { - ObjectPtr n = make_object(); - n->database = database; - n->te_filter_func = te_filter_func; - n->logging_func = logging_func; - if (te_filter_func == nullptr) { - n->te_filter_func = DefaultTaskFilter; - } - data_ = n; -} - -Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, - Target target, Optional> dispatched, - FTakeTuningRecord f_take_tuning_record, - FDirectDispatch f_direct_dispatch) { - ICHECK(dispatched.defined()); - ICHECK_EQ(dispatched.value().size(), 1); - ICHECK(HasOnlyOneFunction(mod)) << mod; - IRModule prim_mod = dispatched.value()[0]; - ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; - - // Keep the original func name to be returned later. - GlobalVar gv = GetOnlyOneFunctionKey(prim_mod).value(); - - // Unify func name to make sure it can be found in database - const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); - ICHECK(parse_mod_func) << "Parse mod function not defined!"; - prim_mod = (*parse_mod_func)(prim_mod); - - if (f_direct_dispatch != nullptr) { - Optional mod = f_direct_dispatch(prim_mod); - if (mod.defined()) { - TVM_PY_LOG(INFO, logging_func) << "Direct dispatch applied for workload: " << task_name; - return mod.value(); - } - } - if (database->HasWorkload(prim_mod)) { - Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); - if (records.size() == 1) { - if (f_take_tuning_record != nullptr) { - f_take_tuning_record(records[0]); - } - tir::Schedule sch = - tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0, - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); - records[0]->trace->ApplyToSchedule(sch, false); - tir::PrimFunc func = GetOnlyOneFunction(sch->mod()).value(); - // Make sure we return the updated PrimFunc paired with the original func name. - return IRModule({{gv, func}}); - } - } - TVM_PY_LOG(WARNING, logging_func) << "Cannot find workload: " << task_name; - return NullOpt; -} - -TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); -TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") - .set_body_typed([](Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func, - PackedFunc logging_func) -> ApplyHistoryBest { - return ApplyHistoryBest(database, te_filter_func, logging_func); - }); -TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope") - .set_body_typed(ApplyHistoryBestInternal::EnterScope); -TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestExitScope") - .set_body_typed(ApplyHistoryBestInternal::ExitScope); -TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestCurrent") - .set_body_typed(ApplyHistoryBest::Current); -TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestQuery") - .set_body_method(&ApplyHistoryBestNode::Query); - -} // namespace meta_schedule -} // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 4e180c4fab61..fedd2aa35278 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -154,6 +154,59 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w return TuningRecord(trace, workload, run_secs, target, args_info); } +/******** Database ********/ + +Optional DatabaseNode::QueryTuningRecord(IRModule mod, Target target) { + if (!this->HasWorkload(mod)) { + return NullOpt; + } + Array records = this->GetTopK(this->CommitWorkload(mod), 1); + if (records.empty()) { + return NullOpt; + } + ICHECK_EQ(records.size(), 1); + return records[0]; +} + +Optional DatabaseNode::QuerySchedule(IRModule mod, Target target) { + if (Optional opt_record = this->QueryTuningRecord(mod, target)) { + TuningRecord record = opt_record.value(); + tir::Schedule sch = + tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + record->trace->ApplyToSchedule(sch, false); + return sch; + } else { + return NullOpt; + } +} + +Optional DatabaseNode::QueryIRModule(IRModule mod, Target target) { + if (Optional opt_sch = this->QuerySchedule(mod, target)) { + return opt_sch.value()->mod(); + } else { + return NullOpt; + } +} + +std::vector* ThreadLocalDatabases() { + static thread_local std::vector tls; + return &tls; +} + +void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); } + +void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); } + +Optional Database::Current() { + std::vector* tls = ThreadLocalDatabases(); + if (tls->empty()) { + return NullOpt; + } else { + return tls->back(); + } +} + /******** PyDatabase ********/ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, @@ -194,6 +247,11 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") + .set_body_method(&Database::EnterWithScope); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") + .set_body_method(&Database::ExitWithScope); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") .set_body_method(&DatabaseNode::HasWorkload); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") @@ -205,6 +263,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK") TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") .set_body_method(&DatabaseNode::GetAllTuningRecords); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") + .set_body_method(&DatabaseNode::QueryTuningRecord); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") + .set_body_method(&DatabaseNode::QuerySchedule); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") + .set_body_method(&DatabaseNode::QueryIRModule); TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index 3406f82eb1f0..ec04361f51ec 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -38,67 +38,6 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } -Optional DefaultTaskFilterImpl(const Array& args, - const Array& constants, - bool allow_extern_op) { - using namespace ::tvm::te; - std::vector stack; - std::unordered_set visited; - for (const Tensor& v : args) { - for (const PrimExpr& e : v->shape) { - // Dynamic shape is not supported for now - if (!e->IsInstance()) { - return NullOpt; - } - } - if (!visited.count(v.get())) { - visited.insert(v.get()); - stack.push_back(v); - } - } - while (!stack.empty()) { - Tensor tensor = stack.back(); - stack.pop_back(); - if (tensor->op->IsInstance()) { - // do nothing - } else if (tensor->op->IsInstance() || - (allow_extern_op && tensor->op->IsInstance())) { - Array inputs = tensor->op->InputTensors(); - for (const Tensor& v : inputs) { - if (!visited.count(v.get())) { - visited.insert(v.get()); - stack.push_back(v); - } - } - } else { - return NullOpt; - } - } - PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); - bool dynamic_loop_extent = false; - PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { - if (const auto* loop = obj.as()) { - if (!loop->extent->IsInstance()) { - dynamic_loop_extent = true; - } - } - }); - if (dynamic_loop_extent) { - return NullOpt; - } - return func; -} - -Optional DefaultTaskFilter(const Array& args, - const Array& constants) { - return DefaultTaskFilterImpl(args, constants, false); -} - -Optional DefaultTaskFilterAllowExtern(const Array& args, - const Array& constants) { - return DefaultTaskFilterImpl(args, constants, true); -} - TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, @@ -106,14 +45,5 @@ TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") return ExtractedTask(task_name, mod, target, dispatched, weight); }); -TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter") - .set_body_typed([](const Array& args, const Array& constants) { - return DefaultTaskFilter(args, constants); - }); - -TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern") - .set_body_typed([](const Array& args, const Array& constants) { - return DefaultTaskFilterAllowExtern(args, constants); - }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 664a6a609e7f..db37935ec206 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -21,7 +21,6 @@ #include #include -#include #include #include #include diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 4f83b6eeed60..213841c621de 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -16,8 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - -#include #include #include #include @@ -32,13 +30,10 @@ namespace tvm { namespace relay { namespace backend { -Array ExtractTask( - IRModule mod, Target target, Map params, - meta_schedule::ApplyHistoryBestNode::FTEFilterFunc filter_func) { +Array ExtractTask(IRModule mod, Target target, + Map params) { using meta_schedule::ExtractedTask; - if (filter_func == nullptr) { - filter_func = tvm::meta_schedule::DefaultTaskFilter; - } + backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter(); backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); @@ -48,7 +43,7 @@ Array ExtractTask( std::vector tasks; std::unordered_map cache; - PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &filter_func](const Expr& exp) { + PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, &tir_converter](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { @@ -62,13 +57,11 @@ Array ExtractTask( } auto [inputs_outputs, constants, fused_name] = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - if (Optional prim_func = filter_func(inputs_outputs, constants)) { - GlobalVar prim_fn_var(fused_name); - IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, prim_func.value()}}); - ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1); - tasks.push_back(extracted_task); - cache.emplace(cache_key, extracted_task); + if (Optional f = tir_converter(inputs_outputs, constants)) { + IRModule relay_mod({{GlobalVar(fused_name), relay_func}}); + ExtractedTask task(fused_name, relay_mod, target, {PrimFuncToIRModule(f.value())}, 1); + tasks.push_back(task); + cache.emplace(cache_key, task); } } }); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 5c79ed2070cc..8fa8610c0fca 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -548,6 +548,7 @@ TECompiler& TECompiler::Global() { TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String); TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { return TECompiler::Global(); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 92cc6f8cfa46..0e2a3e270257 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -61,16 +62,6 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode); TVM_REGISTER_NODE_TYPE(CCacheKeyNode); TVM_REGISTER_NODE_TYPE(CCacheValueNode); -void ExtractTransformLayout(const meta_schedule::TuningRecord& record) { - static tir::InstructionKind kind_transform_layout = tir::InstructionKind::Get("TransformLayout"); - for (const tir::Instruction& inst : record->trace->insts) { - if (inst->kind.same_as(kind_transform_layout)) { - ICHECK_EQ(inst->attrs.size(), 3); - relay::MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast(inst->attrs[2])); - } - } -} - LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { auto n = make_object(); n->outputs = std::move(outputs); @@ -317,11 +308,11 @@ class ScheduleBuilder : public ExprVisitor { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); if (backend::IsMetaScheduleEnabled()) { - meta_schedule_ctx_ = meta_schedule::ApplyHistoryBest::Current(); - CHECK(meta_schedule_ctx_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay " - "build, but no ApplyHistoryBest context is provided. "; + database_ = meta_schedule::Database::Current(); + CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay " + "build, but no `meta_schedule.Database` context is provided. "; } else { - meta_schedule_ctx_ = NullOpt; + database_ = NullOpt; } } @@ -359,32 +350,43 @@ class ScheduleBuilder : public ExprVisitor { schedule = Downcast(obj); } } - if (meta_schedule_ctx_) { + if (database_) { + using tvm::meta_schedule::TuningRecord; + using tvm::tir::IndexMap; + using tvm::tir::Instruction; + using tvm::tir::InstructionKind; + using tvm::tir::PrimFunc; + using tvm::tir::Schedule; + backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter(); Array te_args = Concat(fn_inputs, tensor_outs); Array constants; for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) { te_args.push_back(te_tensor); constants.push_back(const_node->data); } - - if (Optional tir_func = - meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) { - IRModule relay_mod({{prim_fn_var, relay_func}}); - IRModule tir_mod({{prim_fn_var, tir_func.value()}}); - if (Optional opt_scheduled_mod = meta_schedule_ctx_.value()->Query( - /*task_name=*/prim_fn_var->name_hint, // - /*mod=*/relay_mod, // - /*target=*/target_, // - /*dispatched=*/Array{tir_mod}, // - /*f_take_tuning_record=*/ExtractTransformLayout)) { - IRModule scheduled_mod = - tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value()); - ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1); - prim_func = Downcast(scheduled_mod->functions[prim_fn_var]); + if (Optional f = tir_converter(te_args, constants)) { + if (Optional opt_record = database_.value()->QueryTuningRecord( + /*mod=*/backend::PrimFuncToIRModule(f.value()), + /*target=*/target_)) { + static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout"); + TuningRecord record = opt_record.value(); + for (const Instruction& inst : record->trace->insts) { + if (inst->kind.same_as(kind_transform_layout)) { + ICHECK_EQ(inst->attrs.size(), 3); + MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast(inst->attrs[2])); + } + } + Schedule sch = Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, + tir::ScheduleErrorRenderLevel::kDetail); + record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); + IRModule mod = sch->mod(); + ICHECK_EQ(mod->functions.size(), 1); + mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod)); + prim_func = Downcast(mod->Lookup("main")); } } } - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. + // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule. if (!schedule.defined() && !prim_func.defined()) { if (anchor_op_.defined()) { auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->()); @@ -422,7 +424,7 @@ class ScheduleBuilder : public ExprVisitor { } int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && !meta_schedule_ctx_.defined() && op_pattern >= kCommReduce) { + if (!use_auto_scheduler_ && !database_.defined() && op_pattern >= kCommReduce) { ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) << "Cannot apply TOPI schedule to a primitive function with two complicated ops" << " anchor=" << anchor_op_ << " current=" << op; @@ -440,7 +442,7 @@ class ScheduleBuilder : public ExprVisitor { Attrs anchor_attrs_; int anchor_op_pattern_{0}; bool use_auto_scheduler_; - Optional meta_schedule_ctx_; + Optional database_; }; /*! diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 340986770e93..5cf7a5563d19 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -28,6 +28,9 @@ #include #include #include +#include + +#include "../../te/operation/create_primfunc.h" namespace tvm { namespace relay { @@ -368,6 +371,76 @@ void BindParamsInModule(IRModule mod, Map params) { BindParamsInModule(mod, params_tmp); } +/*! + * \brief A default TE compute to TIR compute. + * \param args The inputs/outputs of the TE compute graph. + * \param constants The constants bound to TIR + * \param allow_extern_op Whether to allow extern operation in TE. + * \return The TIR converted; NullOpt if not supported (dynamic shape) + */ +Optional DefaultTIRConverterImpl(const Array& args, + const Array& constants, + bool allow_extern_op) { + using namespace ::tvm::te; + std::vector stack; + std::unordered_set visited; + for (const Tensor& v : args) { + for (const PrimExpr& e : v->shape) { + // Dynamic shape is not supported for now + if (!e->IsInstance()) { + return NullOpt; + } + } + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + while (!stack.empty()) { + Tensor tensor = stack.back(); + stack.pop_back(); + if (tensor->op->IsInstance()) { + // do nothing + } else if (tensor->op->IsInstance() || + (allow_extern_op && tensor->op->IsInstance())) { + Array inputs = tensor->op->InputTensors(); + for (const Tensor& v : inputs) { + if (!visited.count(v.get())) { + visited.insert(v.get()); + stack.push_back(v); + } + } + } else { + return NullOpt; + } + } + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); + bool dynamic_loop_extent = false; + tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { + if (const auto* loop = obj.as()) { + if (!loop->extent->IsInstance()) { + dynamic_loop_extent = true; + } + } + }); + if (dynamic_loop_extent) { + return NullOpt; + } + return func; +} + +TVM_REGISTER_GLOBAL("relay.backend.tir_converter.default") + .set_body_typed([](const Array& args, + const Array& constants) -> Optional { + return DefaultTIRConverterImpl(args, constants, false); + }); + +TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern") + .set_body_typed([](const Array& args, + const Array& constants) -> Optional { + return DefaultTIRConverterImpl(args, constants, true); + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 57c066131181..37ae9d803a35 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -558,6 +558,37 @@ inline bool IsMetaScheduleEnabled() { .value(); } +/*! + * \brief Method in TECompiler to convert TE compute to scheduleable TIR + * \param args The arguments of the TE compute + * \param constants The constants used in AllocateConst + * \return NullOpt if conversion fails; Otherwise the converted TIR + * \note This method could be further used as a task filtering mechanism in task extraction + */ +using FTECompilerTIRConverter = runtime::TypedPackedFunc< // + Optional( // + const Array& args, // + const Array& constants)>; + +/*! \brief Return a task filter for AutoTIR according to `relay.backend.tir_converter` */ +inline FTECompilerTIRConverter GetTIRConverter() { + String name = transform::PassContext::Current() + ->GetConfig("relay.backend.tir_converter", "default") + .value(); + const PackedFunc* f = runtime::Registry::Get("relay.backend.tir_converter." + name); + ICHECK(f != nullptr) << "IndexError: Cannot find TIR converter: " << name; + return FTECompilerTIRConverter(*f); +} + +/*! \brief Converts a PrimFunc to IRModule. */ +inline IRModule PrimFuncToIRModule(tir::PrimFunc f) { + f = WithAttrs(f, Map{ + {tvm::attr::kGlobalSymbol, String("main")}, + {tvm::tir::attr::kNoAlias, Bool(1)}, + }); + return IRModule({{GlobalVar("main"), f}}); +} + /*! * \brief Get the sequence of Relay optimization passes based on backend type. * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py index 3397eaabbef2..7227ef0c7b79 100644 --- a/tests/python/integration/test_meta_schedule_auto_tensorize.py +++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py @@ -19,13 +19,12 @@ import numpy as np import pytest - import tvm import tvm.testing import tvm.topi.testing from tvm import meta_schedule as ms from tvm import relay -from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule +from tvm.meta_schedule import postproc, schedule_rule from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base from tvm.meta_schedule.tune import tune_extracted_tasks @@ -176,12 +175,11 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, pos postprocs=lambda: postprocs, ) - with ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - lib = relay.build(relay_mod, target=target, params=params) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) if "cascadelake" in target: asm = lib.lib.get_source("asm") @@ -267,12 +265,11 @@ def _test_bert_int8(target, sch_rules, postprocs): postprocs=lambda: postprocs, ) - with ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - lib = relay.build(relay_mod, target=target, params=params) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) dev = tvm.device("cuda" if "nvidia" in target else target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 8e299dc935d5..c741ecb59ae0 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -19,20 +19,18 @@ import json import os import re -from io import StringIO from contextlib import redirect_stderr +from io import StringIO import numpy as np - import tvm import tvm.relay import tvm.testing from tvm import meta_schedule as ms from tvm import relay -from tvm.relay.backend import Executor, Runtime from tvm.contrib import utils from tvm.meta_schedule.testing.utils import apply_fixed_schedules - +from tvm.relay.backend import Executor, Runtime INPUT_SHAPE = (1, 3, 16, 16) @@ -421,13 +419,12 @@ def schedule_fn(task, sch): database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) with StringIO() as stderr_buf, redirect_stderr(stderr_buf): - with ms.ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - executor = Executor("graph", {"link-params": link_params}) - lib = relay.build(relay_mod, target=target, executor=executor) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + executor = Executor("graph", {"link-params": link_params}) + lib = relay.build(relay_mod, target=target, executor=executor) # Workload look up should succeed. This does not work when the test is invoked from pytest. assert not "Cannot find workload" in stderr_buf.getvalue() diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index afce19a590e3..69522831ee55 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Integration test for MetaSchedule""" -from typing import Optional import numpy as np import pytest import tvm @@ -23,11 +22,10 @@ from tvm import IRModule from tvm import meta_schedule as ms from tvm import relay, te, tir +from tvm._ffi import register_func from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base from tvm.script import tir as T -from tvm.target import Target -from tvm.tir import Schedule # pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name @@ -58,10 +56,6 @@ def _has_torch(): requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") -def test_meta_schedule_apply_history_best_no_current(): - assert ms.ApplyHistoryBest.current() is None - - def test_meta_schedule_dynamic_loop_extent(): a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32") b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC") @@ -125,7 +119,7 @@ def test_meta_schedule_integration_extract_from_bert_base(): 12, [[64, 768], [3072, 768], [64, 3072]], ), - "fused_subtract_add_sqrt_divide_multiply_add": ( + "fused_subtract_add_rsqrt_multiply_multiply_add": ( 25, [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]], ), @@ -206,7 +200,8 @@ def test_meta_schedule_integration_extract_from_bert_base(): @requires_torch def test_meta_schedule_integration_extract_from_resnet_with_filter_func(): - def filter_func(args) -> bool: + @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True) + def filter_func(args, _) -> bool: from tvm.te import create_prim_func # pylint: disable=import-outside-toplevel has_complex_op = False @@ -236,7 +231,7 @@ def traverse(t): mod, target="llvm", params=params, - te_filter_func=filter_func, + tir_converter="remove_purely_spatial", ) expected_task_names = [ "fused_" + s @@ -267,53 +262,6 @@ def traverse(t): assert t.task_name in expected_task_names, t.task_name -@requires_torch -def test_meta_schedule_integration_apply_history_best(): - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - database = ms.database.MemoryDatabase() - env = ms.ApplyHistoryBest(database) - target = Target("llvm") - workload = database.commit_workload(MockModule) - database.commit_tuning_record( - ms.database.TuningRecord( - trace=Schedule(MockModule).trace, - workload=workload, - run_secs=[1.0], - target=target, - args_info=[], - ) - ) - mod = env.query( - task_name="mock-task", - mod=mod, - target=target, - dispatched=[MockModule], - ) - assert tvm.ir.structural_equal(mod, workload.mod) - - -@requires_torch -def test_meta_schedule_integration_apply_history_best_direct_dispatch(): - def direct_dispatch(mod: IRModule) -> Optional[IRModule]: - if tvm.ir.structural_equal(mod, MockModule): - return MockModule - return None - - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - database = ms.database.MemoryDatabase() - env = ms.ApplyHistoryBest(database) - target = Target("llvm") - workload = database.commit_workload(MockModule) - mod = env.query( - task_name="mock-task-direct-dispatch", - mod=mod, - target=target, - dispatched=[MockModule], - f_direct_dispatch=direct_dispatch, - ) - assert tvm.ir.structural_equal(mod, workload.mod) - - @pytest.mark.skip("Too slow on CI") def extract_task_qbert(): mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py index b7d012ca04d6..177001781179 100644 --- a/tests/python/unittest/test_meta_schedule_multi_anchor.py +++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py @@ -70,7 +70,7 @@ def schedule_fn(task, sch): return False database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - with ms.ApplyHistoryBest(database): + with database: with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py index 058012cb643a..939851a65731 100644 --- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py +++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py @@ -19,7 +19,6 @@ import tvm.testing import tvm.topi.testing from tvm import autotvm, relay, te -from tvm.meta_schedule import ApplyHistoryBest from tvm.meta_schedule.testing.utils import apply_fixed_schedules from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.script import tir as T @@ -152,17 +151,16 @@ def schedule_fn(task, sch): target, params, schedule_fn, - te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern", + tir_converter="allow_extern", ) - with ApplyHistoryBest( - database, - te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern", + with database, tvm.transform.PassContext( + opt_level=3, + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "allow_extern", + }, ): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - lib = relay.build(relay_mod, target=target, params=params) + lib = relay.build(relay_mod, target=target, params=params) dev = tvm.device(target, 0) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 7d85b8757ae2..bc37fed7d691 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -245,12 +245,11 @@ def print_results(self) -> None: database.commit_workload(tvmgen_default_fused_layout_transform_1) database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc) - with ms.ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - rt_mod1 = relay.build(mod, target=target, params=params) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + rt_mod1 = relay.build(mod, target=target, params=params) # Compile without meta-schedule for correctness check with tvm.transform.PassContext(opt_level=0): @@ -307,12 +306,11 @@ def test_meta_schedule_relay_lowering(): args_info=[], ) ) - with ms.ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - rt_mod1 = relay.build(mod, target=target, params=params) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + rt_mod1 = relay.build(mod, target=target, params=params) # Compile without meta-schedule for correctness check with tvm.transform.PassContext(opt_level=0): @@ -472,24 +470,23 @@ def schedule_fn(task, sch): database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - with ms.ApplyHistoryBest(database): - with tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - # pylint: disable=W0105 - """ - The log should say - Warning: Cannot find workload: tvmgen_default_fused_expand_dims - Warning: Cannot find workload: tvmgen_default_fused_cast - Warning: Cannot find workload: tvmgen_default_fused_cast_1 - Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul - - This means batch matmul and others are scheduled by TE, and dense (the one not warned) - is found in the meta schedule tuning database during ApplyHistoryBest - """ - # pylint: enable=W0105 - lib = relay.build(relay_mod, target=target, params=params) + with database, tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + # pylint: disable=W0105 + """ + The log should say + Warning: Cannot find workload: tvmgen_default_fused_expand_dims + Warning: Cannot find workload: tvmgen_default_fused_cast + Warning: Cannot find workload: tvmgen_default_fused_cast_1 + Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul + + This means batch matmul and others are scheduled by TE, and dense (the one not warned) + is found in the meta schedule tuning database during compilation + """ + # pylint: enable=W0105 + lib = relay.build(relay_mod, target=target, params=params) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))