diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 428a2e80f4dd..7a7599b0a4f8 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -82,6 +82,8 @@ class TuneContextNode : public runtime::Object { v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); + v->Visit("builder_results", &builder_results); + v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 5532c472e638..178239e7def1 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -20,13 +20,15 @@ from tvm._ffi import register_object from tvm.ir import IRModule, transform -from tvm.relay import Any, Function as RelayFunc, vm +from tvm.relay import Any +from tvm.relay import Function as RelayFunc +from tvm.relay import vm from tvm.runtime import NDArray, Object from tvm.target import Target from tvm.tir import PrimFunc -from .database import Database from . import _ffi_api +from .database import Database @register_object("meta_schedule.ExtractedTask") diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index faf61f5de3e6..f429e417bac7 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -15,16 +15,16 @@ # specific language governing permissions and limitations # under the License. 
"""User-facing Tuning API""" - +# pylint: disable=import-outside-toplevel import logging import os.path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import tvm -from tvm import relay -from tvm._ffi import register_func -from tvm.ir import IRModule, structural_equal, structural_hash +from tvm._ffi.registry import register_func +from tvm.ir import IRModule, structural_hash from tvm.relay import Function as RelayFunc +from tvm.relay import build as relay_build from tvm.runtime import Module, NDArray from tvm.target import Target from tvm.te import Tensor, create_prim_func @@ -34,7 +34,7 @@ from .cost_model import CostModel, XGBModel from .database import Database, JSONDatabase, TuningRecord from .feature_extractor import PerStoreFeature -from .integration import ApplyHistoryBest, extract_task_from_relay +from .integration import ApplyHistoryBest, ExtractedTask, extract_task_from_relay from .measure_callback import MeasureCallback from .mutator import Mutator from .postproc import Postproc @@ -49,7 +49,6 @@ from .task_scheduler import RoundRobin, TaskScheduler from .tune_context import TuneContext - logger = logging.getLogger(__name__) # pylint: disable=invalid-name SearchStrategyConfig = Union[ @@ -79,9 +78,7 @@ class DefaultLLVM: @staticmethod def _sch_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - schedule_rule as M, - ) + from tvm.meta_schedule import schedule_rule as M return [ M.AutoInline( @@ -117,9 +114,7 @@ def _sch_rules() -> List[ScheduleRule]: @staticmethod def _postproc() -> List[Postproc]: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - postproc as M, - ) + from tvm.meta_schedule import postproc as M return [ M.DisallowDynamicLoop(), @@ -129,9 +124,7 @@ def _postproc() -> List[Postproc]: @staticmethod def _mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import ( # 
pylint: disable=import-outside-toplevel - mutator as M, - ) + from tvm.meta_schedule import mutator as M return { M.MutateTileSize(): 0.9, @@ -146,9 +139,7 @@ class DefaultCUDA: @staticmethod def _sch_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - schedule_rule as M, - ) + from tvm.meta_schedule import schedule_rule as M return [ M.MultiLevelTiling( @@ -170,7 +161,6 @@ def _sch_rules() -> List[ScheduleRule]: M.AutoInline( into_producer=True, into_consumer=True, - # into_cache_only=False, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -188,9 +178,7 @@ def _sch_rules() -> List[ScheduleRule]: @staticmethod def _postproc() -> List[Postproc]: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - postproc as M, - ) + from tvm.meta_schedule import postproc as M return [ M.DisallowDynamicLoop(), @@ -203,12 +191,10 @@ def _postproc() -> List[Postproc]: @staticmethod def _mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - mutator as M, - ) + from tvm.meta_schedule import mutator as M return { - # M.MutateTileSize(): 0.9, + M.MutateTileSize(): 0.9, M.MutateUnroll(): 0.1, } @@ -280,9 +266,7 @@ def _callbacks( measure_callbacks: Optional[List[MeasureCallback]], ) -> List[MeasureCallback]: if measure_callbacks is None: - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - measure_callback as M, - ) + from tvm.meta_schedule import measure_callback as M return [ M.AddToDatabase(), @@ -468,8 +452,6 @@ def tune_tir( The target to tune for. config : SearchStrategyConfig The search strategy config. - task_name : str - The name of the task. work_dir : Optional[str] The working directory to save intermediate results. 
builder : Optional[Builder] @@ -604,14 +586,44 @@ def tune_te( ) -def tune_relay( - mod: Union[RelayFunc, IRModule], - target: Union[str, Target], +def deduplicate_extracted_tasks( + extracted_tasks: List[ExtractedTask], +) -> Tuple[List[ExtractedTask], List[int]]: + """Remove duplicate extracted tasks. + + Parameters + ---------- + extracted_tasks : List[ExtractedTask] + The list of extracted tasks. + + Returns + ------- + tasks : Tuple[List[ExtractedTask], List[int]] + A tuple containing the deduplicated extracted tasks and the count for each task. + """ + hash2idx: Dict[int, int] = {} + dedup: List[ExtractedTask] = [] + count: List[int] = [] + + for task in extracted_tasks: + assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" + mod = Parse._mod(task.dispatched[0]) + shash = structural_hash(mod) + if shash in hash2idx: + count[hash2idx[shash]] += 1 + else: + hash2idx[shash] = len(dedup) + dedup.append(task) + count.append(1) + return dedup, count + + +def tune_extracted_tasks( + extracted_tasks: List[ExtractedTask], + target: Target, config: SearchStrategyConfig, work_dir: str, *, - params: Optional[Dict[str, NDArray]] = None, - task_name: str = "main", builder: Optional[Builder] = None, runner: Optional[Runner] = None, database: Optional[Database] = None, @@ -623,21 +635,17 @@ def tune_relay( postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, num_threads: Optional[int] = None, -) -> Module: - """Tune a TIR IRModule with a given target. +) -> Database: + """Tune extracted tasks with a given target. Parameters ---------- - mod : Union[RelayFunc, IRModule] - The module to tune. + extracted_tasks : List[ExtractedTask] + The list of extracted tasks. target : Union[str, Target] The target to tune for. config : SearchStrategyConfig The search strategy config. - params : Optional[Dict[str, tvm.runtime.NDArray]] - The associated parameters of the program - task_name : str - The name of the task. 
work_dir : Optional[str] The working directory to save intermediate results. builder : Optional[Builder] @@ -646,26 +654,37 @@ def tune_relay( The runner to use. database : Optional[Database] The database to use. + cost_model : Optional[CostModel] + The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. + task_scheduler : Optional[TaskScheduler] + The task scheduler to use. + space : Optional[FnSpaceGenerator] + The space generator to use. + sch_rules : Optional[FnScheduleRule] + The search rules to use. + postprocs : Optional[FnPostproc] + The postprocessors to use. + mutator_probs : Optional[FnMutatorProb] + The probability distribution to use different mutators. + num_threads : Optional[int] + The number of threads to use. Returns ------- - lib : Module - The built runtime module for the given relay workload. - """ + database : Database + The database containing all the tuning results. 
- logger.info("Working directory: %s", work_dir) - extracted_tasks = extract_task_from_relay(mod, target, params) + """ + # deduplication + logger.info("Before task deduplication: %d tasks", len(extracted_tasks)) + extracted_tasks, _ = deduplicate_extracted_tasks(extracted_tasks) + logger.info("After task deduplication: %d tasks", len(extracted_tasks)) # pylint: disable=protected-access - tune_contexts = [] target = Parse._target(target) - database = Parse._database(database, task_name, work_dir) # parse the tuning contexts + tune_contexts = [] for task in extracted_tasks: assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" tune_contexts.append( @@ -682,27 +701,11 @@ def tune_relay( num_threads=num_threads, ) ) - # deduplication - logger.info("Before task deduplication: %d tasks", len(tune_contexts)) - tasks: List[TuneContext] = [] - hashs: List[int] = [] - for i, task in enumerate(tune_contexts): - struct_hash: int = structural_hash(task.mod) - flag: bool = False - if struct_hash in hashs: - for other_task in tune_contexts[i + 1 :]: - if structural_equal(task.mod, other_task.mod): - flag = True - break - if not flag: - tasks.append(task) - hashs.append(struct_hash) - logger.info("After task deduplication: %d tasks", len(tasks)) - # parse the task scheduler + database = Parse._database(database, "default", work_dir) task_scheduler = Parse._task_scheduler( task_scheduler, - tasks, + tune_contexts, builder=Parse._builder(builder), runner=Parse._runner(runner), database=database, @@ -711,9 +714,85 @@ def tune_relay( ) # pylint: enable=protected-access task_scheduler.tune() + return database + + +def tune_relay( + mod: Union[RelayFunc, IRModule], + target: Union[str, Target], + config: SearchStrategyConfig, + work_dir: str, + *, + params: Optional[Dict[str, NDArray]] = None, + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + 
measure_callbacks: Optional[List[MeasureCallback]] = None, + task_scheduler: Optional[TaskScheduler] = None, + space: Optional[FnSpaceGenerator] = None, + sch_rules: Optional[FnScheduleRule] = None, + postprocs: Optional[FnPostproc] = None, + mutator_probs: Optional[FnMutatorProb] = None, + num_threads: Optional[int] = None, +) -> Module: + """Tune a TIR IRModule with a given target. + + Parameters + ---------- + mod : Union[RelayFunc, IRModule] + The module to tune. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + cost_model : Optional[CostModel] + The cost model to use. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + task_scheduler : Optional[TaskScheduler] + The task scheduler to use. + space : Optional[FnSpaceGenerator] + The space generator to use. + + Returns + ------- + lib : Module + The built runtime module for the given relay workload. 
+ """ + + logger.info("Working directory: %s", work_dir) + extracted_tasks = extract_task_from_relay(mod, target, params) + database = tune_extracted_tasks( + extracted_tasks, + target, + config, + work_dir, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + sch_rules=sch_rules, + postprocs=postprocs, + mutator_probs=mutator_probs, + num_threads=num_threads, + ) with ApplyHistoryBest(database): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): - return relay.build(mod, target=target, params=params) + return relay_build(mod, target=target, params=params) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 277fa2407bd1..efa1183814c8 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -112,7 +112,6 @@ def _sch_rules(): M.AutoInline( into_producer=False, into_consumer=True, - # into_cache_only=False, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -139,7 +138,6 @@ def _sch_rules(): M.AutoInline( into_producer=True, into_consumer=True, - # into_cache_only=True, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -161,10 +159,10 @@ def _postproc(): ) return [ - # M.RewriteCooperativeFetch(), + M.RewriteCooperativeFetch(), M.RewriteParallelVectorizeUnroll(), M.RewriteReductionBlock(), - # M.RewriteTensorCore(), + M.RewriteTensorCore(), M.VerifyGPUCode(), ]