From e3fbb797a88308e4ce3d671939a83084ae1826b8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 31 Mar 2022 01:34:23 -0700 Subject: [PATCH] [MetaSchedule] Add Gradient Based Task Scheduler --- include/tvm/meta_schedule/search_strategy.h | 12 +- include/tvm/meta_schedule/task_scheduler.h | 94 ++++---- include/tvm/meta_schedule/tune_context.h | 4 +- include/tvm/support/random_engine.h | 18 +- include/tvm/tir/schedule/schedule.h | 2 +- .../search_strategy/evolutionary_search.py | 13 +- .../search_strategy/replay_func.py | 13 +- .../search_strategy/replay_trace.py | 16 +- .../meta_schedule/task_scheduler/__init__.py | 1 + .../task_scheduler/gradient_based.py | 93 +++++++ .../task_scheduler/round_robin.py | 15 +- .../task_scheduler/task_scheduler.py | 86 +++---- .../testing/tune_relay_meta_schedule.py | 8 +- .../testing/tune_te_meta_schedule.py | 3 +- python/tvm/meta_schedule/tune.py | 14 +- python/tvm/meta_schedule/utils.py | 2 +- .../measure_callback/echo_statistics.cc | 10 +- .../search_strategy/evolutionary_search.cc | 16 +- .../search_strategy/replay_func.cc | 12 +- .../search_strategy/replay_trace.cc | 12 +- .../task_scheduler/gradient_based.cc | 228 ++++++++++++++++++ .../task_scheduler/round_robin.cc | 10 +- .../task_scheduler/task_scheduler.cc | 134 +++++----- src/meta_schedule/tune_context.cc | 5 +- src/meta_schedule/utils.h | 23 ++ src/support/table_printer.h | 154 ++++++++++++ src/tir/schedule/concrete_schedule.cc | 3 - src/tir/schedule/concrete_schedule.h | 2 +- .../test_meta_schedule_measure_callback.py | 9 +- .../test_meta_schedule_search_strategy.py | 16 +- .../test_meta_schedule_task_scheduler.py | 138 ++++++++--- .../unittest/test_meta_schedule_tune_relay.py | 23 +- .../unittest/test_meta_schedule_tune_te.py | 3 +- .../unittest/test_meta_schedule_tune_tir.py | 25 +- 34 files changed, 894 insertions(+), 323 deletions(-) create mode 100644 python/tvm/meta_schedule/task_scheduler/gradient_based.py create mode 100644 src/meta_schedule/task_scheduler/gradient_based.cc create mode 100644 src/support/table_printer.h diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 0a4024915def..6895673a04cc 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -252,21 +252,21 @@ class SearchStrategy : public runtime::ObjectRef { /*! * \brief Constructor of replay trace search strategy. * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param num_trials_total The total number of trials for trace replaying. + * \param max_trials_per_task The total number of trials for trace replaying. */ - TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task); /*! * \brief Constructor of replay func search strategy. * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param num_trials_total The total number of trials for func replaying. + * \param max_trials_per_task The total number of trials for func replaying. */ - TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task); /*! * \brief Constructor of evolutionary search strategy. * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param num_trials_total The total number of trials for evolutionary search. + * \param max_trials_per_task The total number of trials for evolutionary search. * \param population_size The initial sample population. * \param init_measured_ratio The ratio of measures samples in initial population. * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling. @@ -276,7 +276,7 @@ class SearchStrategy : public runtime::ObjectRef { * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. */ TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // - int num_trials_total, // + int max_trials_per_task, // int population_size, // double init_measured_ratio, // int init_min_unmeasured, // diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index ddd6f4c4815f..81d340d33e6b 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -75,10 +75,14 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The maximum number of trials allowed. */ + int max_trials; /*! \brief The cost model of the scheduler. */ Optional cost_model; /*! \brief The list of measure callbacks of the scheduler. */ Array measure_callbacks; + /*! \brief The number of trials already conducted. */ + int num_trials_already; /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; @@ -88,8 +92,10 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("max_trials", &max_trials); v->Visit("cost_model", &cost_model); v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("num_trials_already", &num_trials_already); } /*! \brief Auto-tuning. */ @@ -102,23 +108,16 @@ class TaskSchedulerNode : public runtime::Object { virtual void InitializeTask(int task_id); /*! - * \brief Set specific task to be stopped. - * \param task_id The task id to be stopped. - */ - virtual void SetTaskStopped(int task_id); - - /*! - * \brief Check whether the task is running. + * \brief Touch the task and update its status * \param task_id The task id to be checked. - * \return Whether the task is running. */ - virtual bool IsTaskRunning(int task_id); + virtual void TouchTask(int task_id); /*! * \brief Wait until the task is finished. * \param task_id The task id to be joined. */ - virtual void JoinRunningTask(int task_id); + virtual Array JoinRunningTask(int task_id); /*! * \brief Fetch the next task id. @@ -142,23 +141,17 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { using FInitializeTask = runtime::TypedPackedFunc; /*! - * \brief The function type of `SetTaskStopped` method. - * \param task_id The task id to be stopped. - */ - using FSetTaskStopped = runtime::TypedPackedFunc; - - /*! - * \brief The function type of `IsTaskRunning` method. + * \brief The function type of `TouchTask` method. * \param task_id The task id to be checked. * \return Whether the task is running. */ - using FIsTaskRunning = runtime::TypedPackedFunc; + using FTouchTask = runtime::TypedPackedFunc; /*! * \brief The function type of `JoinRunningTask` method. * \param task_id The task id to be joined. */ - using FJoinRunningTask = runtime::TypedPackedFunc; + using FJoinRunningTask = runtime::TypedPackedFunc(int)>; /*! * \brief The function type of `NextTaskId` method. @@ -170,10 +163,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { FTune f_tune; /*! \brief The packed function to the `InitializeTask` function. */ FInitializeTask f_initialize_task; - /*! \brief The packed function to the `SetTaskStopped` function. */ - FSetTaskStopped f_set_task_stopped; - /*! \brief The packed function to the `IsTaskRunning` function. */ - FIsTaskRunning f_is_task_running; + /*! \brief The packed function to the `TouchTask` function. */ + FTouchTask f_touch_task; /*! \brief The packed function to the `JoinRunningTask` function. */ FJoinRunningTask f_join_running_task; /*! \brief The packed function to the `NextTaskId` function. */ @@ -182,8 +173,7 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { void VisitAttrs(tvm::AttrVisitor* v) { // `f_tune` is not visited // `f_initialize_task` is not visited - // `f_set_task_stopped` is not visited - // `f_is_task_running` is not visited + // `f_touch_task` is not visited // `f_join_running_task` is not visited // `f_next_task_id` is not visited } @@ -204,23 +194,15 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { } } - void SetTaskStopped(int task_id) final { - if (f_set_task_stopped == nullptr) { - TaskSchedulerNode::SetTaskStopped(task_id); - } else { - f_set_task_stopped(task_id); - } - } - - bool IsTaskRunning(int task_id) final { - if (f_is_task_running == nullptr) { - return TaskSchedulerNode::IsTaskRunning(task_id); + void TouchTask(int task_id) final { + if (f_touch_task == nullptr) { + return TaskSchedulerNode::TouchTask(task_id); } else { - return f_is_task_running(task_id); + return f_touch_task(task_id); } } - void JoinRunningTask(int task_id) final { + Array JoinRunningTask(int task_id) final { if (f_join_running_task == nullptr) { return TaskSchedulerNode::JoinRunningTask(task_id); } else { @@ -249,6 +231,7 @@ class TaskScheduler : public runtime::ObjectRef { * \param builder The builder of the scheduler. * \param runner The runner of the scheduler. * \param database The database of the scheduler. + * \param max_trials The maximum number of trials. * \param cost_model The cost model of the scheduler. * \param measure_callbacks The measure callbacks of the scheduler. * \return The task scheduler created. @@ -257,20 +240,47 @@ class TaskScheduler : public runtime::ObjectRef { Builder builder, // Runner runner, // Database database, // + int max_trials, // Optional cost_model, // Optional> measure_callbacks); + /*! + * \brief Create a task scheduler that fetches tasks in a gradient based fashion. + * \param tasks The tasks to be tuned. + * \param task_weights The weights of each task. + * \param builder The builder of the scheduler. + * \param runner The runner of the scheduler. + * \param database The database of the scheduler. + * \param max_trials The maximum number of trials. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \param alpha The parameter alpha to control gradient computation. + * \param window_size The parameter to control backward window size. + * \param seed The random seed. + * \return The task scheduler created. + */ + TVM_DLL static TaskScheduler GradientBased(Array tasks, + Array task_weights, // + Builder builder, // + Runner runner, // + Database database, // + int max_trials, // + Optional cost_model, // + Optional> measure_callbacks, // + double alpha, // + int window_size, // + support::LinearCongruentialEngine::TRandState seed); /*! * \brief Create a task scheduler with customized methods on the python-side. * \param tasks The tasks to be tuned. * \param builder The builder of the scheduler. * \param runner The runner of the scheduler. * \param database The database of the scheduler. + * \param max_trials The maximum number of trials. * \param cost_model The cost model of the scheduler. * \param measure_callbacks The measure callbacks of the scheduler. * \param f_tune The packed function of `Tune`. * \param f_initialize_task The packed function of `InitializeTask`. - * \param f_set_task_stopped The packed function of `SetTaskStopped`. - * \param f_is_task_running The packed function of `IsTaskRunning`. + * \param f_touch_task The packed function of `TouchTask`. * \param f_join_running_task The packed function of `JoinRunningTask`. * \param f_next_task_id The packed function of `NextTaskId`. * \return The task scheduler created. @@ -280,12 +290,12 @@ class TaskScheduler : public runtime::ObjectRef { Builder builder, // Runner runner, // Database database, // + int max_trials, // Optional cost_model, // Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // - PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // - PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // + PyTaskSchedulerNode::FTouchTask f_touch_task, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // PyTaskSchedulerNode::FNextTaskId f_next_task_id); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode); diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 7a7599b0a4f8..1d2978c90533 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -62,7 +62,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The task scheduler that owns the tune context */ const TaskSchedulerNode* task_scheduler; /*! \brief Whether the tuning task has been stopped or finished. */ - bool is_stopped; + bool is_terminated; /*! \brief The measure candidates. */ Optional> measure_candidates; /*! \brief The building results. */ @@ -81,7 +81,7 @@ class TuneContextNode : public runtime::Object { v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); - v->Visit("is_stopped", &is_stopped); + v->Visit("is_terminated", &is_terminated); v->Visit("builder_results", &builder_results); v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index fe56bb51eddd..d9a8a583ce9c 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -99,15 +99,15 @@ class LinearCongruentialEngine { * \brief Change the start random state of RNG with the seed of a new random state value. * \param rand_state The random state given in result_type. */ - void Seed(TRandState rand_state = 1) { - ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!"; - rand_state %= modulus; // Make sure the seed is within the range of modulus. - if (rand_state == 0) - rand_state = 1; // Avoid getting all 0 given the current parameter set. - else if (rand_state < 0) - rand_state += modulus; // Make sure the rand state is non-negative. - ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. - *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. + void Seed(TRandState rand_state) { + if (rand_state == -1) { + rand_state = DeviceRandom(); + } else if (rand_state == 0) { + rand_state = 1; + } + ICHECK(rand_state >= 0) << "The random state should be nonnegative"; + ICHECK(rand_state_ptr_ != nullptr); + *rand_state_ptr_ = rand_state % modulus; } /*! diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 1d9bfc9843b5..e78cef2cacf2 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -128,7 +128,7 @@ class ScheduleNode : public runtime::Object { * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0; + virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0; /*! \brief Fork the random state */ virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index bfc5df52b1c8..c302d570c2aa 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -34,7 +34,7 @@ class EvolutionarySearch(SearchStrategy): ---------- num_trials_per_iter : int Number of trials per iteration. - num_trials_total : int + max_trials_per_task : int Total number of trials. population_size : int The initial population of traces from measured samples and randomly generated samples. @@ -53,7 +53,7 @@ class EvolutionarySearch(SearchStrategy): """ num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int population_size: int init_measured_ratio: int init_min_unmeasured: int @@ -66,7 +66,7 @@ def __init__( self, *, num_trials_per_iter: int, - num_trials_total: int, + max_trials_per_task: int, population_size: int, init_measured_ratio: float, init_min_unmeasured: int, @@ -79,7 +79,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member num_trials_per_iter, - num_trials_total, + max_trials_per_task, population_size, init_measured_ratio, init_min_unmeasured, @@ -94,7 +94,8 @@ class EvolutionarySearchConfig(NamedTuple): """Configuration for EvolutionarySearch""" num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int + max_trials_global: int population_size: int = 2048 init_measured_ratio: float = 0.2 init_min_unmeasured: int = 50 @@ -106,7 +107,7 @@ class EvolutionarySearchConfig(NamedTuple): def create_strategy(self) -> EvolutionarySearch: return EvolutionarySearch( num_trials_per_iter=self.num_trials_per_iter, - num_trials_total=self.num_trials_total, + max_trials_per_task=self.max_trials_per_task, population_size=self.population_size, init_measured_ratio=self.init_measured_ratio, init_min_unmeasured=self.init_min_unmeasured, diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index eacc2776fcbb..ef1fd07527bd 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -33,23 +33,23 @@ class ReplayFunc(SearchStrategy): ---------- num_trials_per_iter : int Number of trials per iteration. - num_trials_total : int + max_trials_per_task : int Total number of trials. """ num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int def __init__( self, num_trials_per_iter: int, - num_trials_total: int, + max_trials_per_task: int, ): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member num_trials_per_iter, - num_trials_total, + max_trials_per_task, ) @@ -57,7 +57,8 @@ class ReplayFuncConfig(NamedTuple): """Configuration for ReplayFunc""" num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int + max_trials_global: int def create_strategy(self) -> ReplayFunc: - return ReplayFunc(self.num_trials_per_iter, self.num_trials_total) + return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 5655038d2ead..ec4fa88b5f3e 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -18,8 +18,9 @@ from typing import NamedTuple from tvm._ffi import register_object -from .search_strategy import SearchStrategy + from .. import _ffi_api +from .search_strategy import SearchStrategy @register_object("meta_schedule.ReplayTrace") @@ -32,19 +33,19 @@ class ReplayTrace(SearchStrategy): ---------- num_trials_per_iter : int Number of trials per iteration. - num_trials_total : int + max_trials_per_task : int Total number of trials. """ num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int - def __init__(self, num_trials_per_iter: int, num_trials_total: int): + def __init__(self, num_trials_per_iter: int, max_trials_per_task: int): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, - num_trials_total, + max_trials_per_task, ) @@ -52,7 +53,8 @@ class ReplayTraceConfig(NamedTuple): """Configuration for ReplayTrace""" num_trials_per_iter: int - num_trials_total: int + max_trials_per_task: int + max_trials_global: int def create_strategy(self) -> ReplayTrace: - return ReplayTrace(self.num_trials_per_iter, self.num_trials_total) + return ReplayTrace(self.num_trials_per_iter, self.max_trials_per_task) diff --git a/python/tvm/meta_schedule/task_scheduler/__init__.py b/python/tvm/meta_schedule/task_scheduler/__init__.py index dbfe962d9966..1a67aa6f6831 100644 --- a/python/tvm/meta_schedule/task_scheduler/__init__.py +++ b/python/tvm/meta_schedule/task_scheduler/__init__.py @@ -22,3 +22,4 @@ """ from .task_scheduler import TaskScheduler, PyTaskScheduler from .round_robin import RoundRobin +from .gradient_based import GradientBased diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py new file mode 100644 index 000000000000..b0b13001382a --- /dev/null +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Gradient Based Task Scheduler""" +from typing import TYPE_CHECKING, List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from ..builder import Builder +from ..cost_model import CostModel +from ..database import Database +from ..measure_callback import MeasureCallback +from ..runner import Runner +from .task_scheduler import TaskScheduler + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.GradientBased") +class GradientBased(TaskScheduler): + """Gradient Based Task Scheduler""" + + def __init__( + self, + tasks: List["TuneContext"], + task_weights: List[float], + builder: Builder, + runner: Runner, + database: Database, + max_trials: int, + *, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + alpha: float = 0.2, + window_size: int = 3, + seed: int = -1, + ) -> None: + """Constructor. + + Parameters + ---------- + tasks : List[TuneContext] + List of tasks to schedule. + task_weights : List[float] + The weights of each task. + builder : Builder + The builder. + runner : Runner + The runner. + database : Database + The database. + max_trials : int + The maximum number of trials to run. + cost_model : CostModel, default None. + The cost model of the scheduler. + measure_callbacks : Optional[List[MeasureCallback]] = None + The list of measure callbacks of the scheduler. + alpha : float = 0.2 + The parameter alpha in gradient computation. + window_size : int = 3 + The parameter to control backward window size in gradient computation. + seed : int = -1 + The random seed. + """ + self.__init_handle_by_constructor__( + _ffi_api.TaskSchedulerGradientBased, # type: ignore # pylint: disable=no-member + tasks, + task_weights, + builder, + runner, + database, + max_trials, + cost_model, + measure_callbacks, + alpha, + window_size, + seed, + ) diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index a63d9a3f2183..16d06ab1fd72 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -16,19 +16,18 @@ # under the License. """Round Robin Task Scheduler""" -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from tvm._ffi import register_object from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback +from .. import _ffi_api from ..builder import Builder -from ..runner import Runner -from ..database import Database from ..cost_model import CostModel +from ..database import Database +from ..runner import Runner from .task_scheduler import TaskScheduler -from .. import _ffi_api - if TYPE_CHECKING: from ..tune_context import TuneContext @@ -57,6 +56,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + max_trials: int, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: @@ -72,6 +72,10 @@ def __init__( The runner. database : Database The database. + max_trials : int + The maximum number of trials. + cost_model : Optional[CostModel] + The cost model. measure_callbacks: Optional[List[MeasureCallback]] The list of measure callbacks of the scheduler. """ @@ -81,6 +85,7 @@ def __init__( builder, runner, database, + max_trials, cost_model, measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index c60d56b39fd0..d3bc25c1e03a 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -19,15 +19,15 @@ from typing import Callable, List, Optional from tvm._ffi import register_object -from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object -from ..runner import Runner +from .. import _ffi_api from ..builder import Builder -from ..database import Database from ..cost_model import CostModel +from ..database import Database +from ..measure_callback import MeasureCallback +from ..runner import Runner, RunnerResult from ..tune_context import TuneContext -from .. import _ffi_api @register_object("meta_schedule.TaskScheduler") @@ -44,16 +44,24 @@ class TaskScheduler(Object): The runner of the scheduler. database: Database The database of the scheduler. + max_trials : int + The maximum number of trials allowed. + cost_model : Optional[CostModel] + The cost model used for search. measure_callbacks: List[MeasureCallback] = None The list of measure callbacks of the scheduler. + num_trials_already : int + The number of trials already conducted. """ tasks: List[TuneContext] builder: Builder runner: Runner database: Database + max_trials: int cost_model: Optional[CostModel] measure_callbacks: List[MeasureCallback] + num_trials_already: int def tune(self) -> None: """Auto-tuning.""" @@ -69,15 +77,20 @@ def next_task_id(self) -> int: """ return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member - def join_running_task(self, task_id: int) -> None: + def join_running_task(self, task_id: int) -> List[RunnerResult]: """Wait until the task is finished. Parameters ---------- task_id : int The task id to be joined. + + Returns + ------- + results : List[RunnerResult] + The list of results. """ - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member + return _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member def initialize_task(self, task_id: int) -> None: """Initialize modules of the given task. @@ -89,30 +102,15 @@ def initialize_task(self, task_id: int) -> None: """ _ffi_api.TaskSchedulerInitializeTask(self, task_id) # type: ignore # pylint: disable=no-member - def set_task_stopped(self, task_id: int) -> None: - """Set specific task to be stopped. - - Parameters - ---------- - task_id : int - The task id to be stopped. - """ - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member - - def is_task_running(self, task_id: int) -> bool: - """Check whether the task is running. + def touch_task(self, task_id: int) -> None: + """Touch the task and update its status Parameters ---------- task_id : int The task id to be checked. - - Returns - ------- - running : bool - Whether the task is running. """ - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore # pylint: disable=no-member @register_object("meta_schedule.PyTaskScheduler") @@ -130,12 +128,12 @@ def __init__( builder: Builder, runner: Runner, database: Database, + max_trials: int, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, f_tune: Callable = None, f_initialize_task: Callable = None, - f_set_task_stopped: Callable = None, - f_is_task_running: Callable = None, + f_touch_task: Callable = None, f_join_running_task: Callable = None, f_next_task_id: Callable = None, ): @@ -147,12 +145,12 @@ def __init__( builder, runner, database, + max_trials, cost_model, measure_callbacks, f_tune, f_initialize_task, - f_set_task_stopped, - f_is_task_running, + f_touch_task, f_join_running_task, f_next_task_id, ) @@ -173,14 +171,14 @@ class PyTaskScheduler: "builder", "runner", "database", + "max_trials", "cost_model", "measure_callbacks", ], "methods": [ "tune", "initialize_task", - "set_task_stopped", - "is_task_running", + "touch_task", "join_running_task", "next_task_id", ], @@ -192,6 +190,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + max_trials: int, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, ): @@ -199,6 +198,7 @@ def __init__( self.builder = builder self.runner = runner self.database = database + self.max_trials = max_trials self.cost_model = cost_model self.measure_callbacks = measure_callbacks @@ -217,7 +217,7 @@ def next_task_id(self) -> int: """ raise NotImplementedError - def join_running_task(self, task_id: int) -> None: + def join_running_task(self, task_id: int) -> List[RunnerResult]: """Wait until the task is finished. Parameters @@ -226,7 +226,7 @@ def join_running_task(self, task_id: int) -> None: The task id to be joined. """ # Using self._outer to replace the self pointer - _ffi_api.TaskSchedulerJoinRunningTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member + return _ffi_api.TaskSchedulerJoinRunningTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member def initialize_task(self, task_id: int) -> None: """Initialize modules of the given task. @@ -239,29 +239,13 @@ def initialize_task(self, task_id: int) -> None: # Using self._outer to replace the self pointer _ffi_api.TaskSchedulerInitializeTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member - def set_task_stopped(self, task_id: int) -> None: - """Set specific task to be stopped. - - Parameters - ---------- - task_id : int - The task id to be stopped. - """ - # Using self._outer to replace the self pointer - _ffi_api.TaskSchedulerSetTaskStopped(self._outer(), task_id) # type: ignore # pylint: disable=no-member - - def is_task_running(self, task_id: int) -> bool: - """Check whether the task is running. + def touch_task(self, task_id: int) -> None: + """Touch the task and update its status Parameters ---------- task_id : int The task id to be checked. - - Returns - ------- - running : bool - Whether the task is running. """ # Using self._outer to replace the self pointer - return _ffi_api.TaskSchedulerIsTaskRunning(self._outer(), task_id) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerTouchTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index dde1b1f0489c..5859412ebbf0 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -129,9 +129,11 @@ def tune_each_task( task_scheduler = ms.tune.Parse._task_scheduler( None, [tune_context], + task_weights=[1.0], builder=ms.tune.Parse._builder(None), runner=ms.tune.Parse._runner(runner), database=database, + max_trials=config.max_trials_per_task, cost_model=ms.tune.Parse._cost_model(None), measure_callbacks=ms.tune.Parse._callbacks(None), ) @@ -167,12 +169,14 @@ def main(): alloc_repeat=alloc_repeat, max_workers=ARGS.rpc_workers, ) - lib = tune_each_task( # or ms.tune_relay + # lib = tune_each_task( + lib = ms.tune_relay( mod=mod, target=ARGS.target, config=ms.EvolutionarySearchConfig( num_trials_per_iter=64, - num_trials_total=ARGS.num_trials, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, init_min_unmeasured=50, ), runner=runner, # type: ignore diff --git a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py index ceace160ea57..abba94ad7a5e 100644 --- a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py @@ -102,7 +102,8 @@ def main(): target=ARGS.target, config=ms.EvolutionarySearchConfig( num_trials_per_iter=64, - num_trials_total=ARGS.num_trials, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, init_min_unmeasured=50, ), runner=runner, # type: ignore diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index ba574010152b..c65e92aec3c7 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -46,7 +46,7 @@ ReplayTraceConfig, ) from .space_generator import PostOrderApply, SpaceGenerator -from .task_scheduler import RoundRobin, TaskScheduler +from .task_scheduler import GradientBased, TaskScheduler from .tune_context import TuneContext from .utils import autotvm_silencer @@ -64,6 +64,7 @@ FnTaskScheduler = Callable[ [ List[TuneContext], + List[float], Builder, Runner, Database, @@ -393,24 +394,29 @@ def _tune_context( def _task_scheduler( task_scheduler: Union[None, TaskScheduler, FnTaskScheduler], tasks: List[TuneContext], + task_weights: List[float], builder: Builder, runner: Runner, database: Database, + max_trials: int, cost_model: CostModel, measure_callbacks: List[MeasureCallback], ): if task_scheduler is None: - return RoundRobin( + return GradientBased( tasks=tasks, + task_weights=task_weights, builder=builder, runner=runner, database=database, + max_trials=max_trials, cost_model=cost_model, measure_callbacks=measure_callbacks, ) if callable(task_scheduler): return task_scheduler( tasks, + task_weights, builder, runner, database, @@ -495,9 +501,11 @@ def tune_tir( task_scheduler = Parse._task_scheduler( task_scheduler, [tune_context], + task_weights=[1.0], builder=Parse._builder(builder), runner=Parse._runner(runner), database=database, + max_trials=config.max_trials_global, cost_model=Parse._cost_model(cost_model), measure_callbacks=Parse._callbacks(measure_callbacks), ) @@ -707,9 +715,11 @@ def tune_extracted_tasks( task_scheduler = Parse._task_scheduler( task_scheduler, tune_contexts, + task_weights=[float(t.weight) for t in extracted_tasks], builder=Parse._builder(builder), runner=Parse._runner(runner), database=database, + max_trials=config.max_trials_global, cost_model=Parse._cost_model(cost_model), measure_callbacks=Parse._callbacks(measure_callbacks), ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 6b36ace98586..8ea1c28b2dc6 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -53,7 +53,7 @@ class _PyRunner(meta_schedule.Runner): def __init__(self, f_run: Callable = None): self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner, f_run) - class PyRunner(): + class PyRunner: _tvm_metadata = { "cls": _PyRunner, "methods": ["run"] diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc index ae7a4826c947..f287596ffbbb 100644 --- a/src/meta_schedule/measure_callback/echo_statistics.cc +++ b/src/meta_schedule/measure_callback/echo_statistics.cc @@ -31,14 +31,6 @@ std::string GetTaskName(const TuneContext& task, int task_id) { return os.str(); } -double GetRunMs(const Array& run_secs) { - double total = 0.0; - for (const FloatImm& i : run_secs) { - total += i->value; - } - return total * 1e3 / run_secs.size(); -} - struct TaskInfo { std::string name; double flop = 0.0; @@ -103,7 +95,7 @@ class EchoStatisticsNode : public MeasureCallbackNode { info.UpdateError(err.value(), candidate); } else { ICHECK(runner_result->run_secs.defined()); - info.Update(GetRunMs(runner_result->run_secs.value())); + info.Update(GetRunMsMedian(runner_result)); } } } diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 24d15b149e70..365d2d69225d 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -325,7 +325,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The number of trials per iteration. */ int num_trials_per_iter; /*! \brief The number of total trials. */ - int num_trials_total; + int max_trials_per_task; /*! \brief The population size in the evolutionary search. */ int population_size; /*! @@ -363,7 +363,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { // `state_` is not visited /*** Configuration: global ***/ - v->Visit("num_trials_total", &num_trials_total); + v->Visit("max_trials_per_task", &max_trials_per_task); v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("population_size", &population_size); v->Visit("num_empty_iters_before_early_stop", &num_empty_iters_before_early_stop); @@ -640,13 +640,13 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( } Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { - if (st >= self->num_trials_total) { + if (st >= self->max_trials_per_task) { return NullOpt; } int sample_num = self->num_trials_per_iter; - if (ed > self->num_trials_total) { - sample_num = self->num_trials_total - st; - ed = self->num_trials_total; + if (ed > self->max_trials_per_task) { + sample_num = self->max_trials_per_task - st; + ed = self->max_trials_per_task; } ICHECK_LT(st, ed); int pop = self->population_size; @@ -681,7 +681,7 @@ void EvolutionarySearchNode::State::NotifyRunnerResults( } SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // - int num_trials_total, // + int max_trials_per_task, // int population_size, // double init_measured_ratio, // int init_min_unmeasured, // @@ -694,7 +694,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, / TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); ObjectPtr n = make_object(); n->num_trials_per_iter = num_trials_per_iter; - n->num_trials_total = num_trials_total; + n->max_trials_per_task = max_trials_per_task; n->population_size = population_size; n->num_empty_iters_before_early_stop = 5; n->init_measured_ratio = init_measured_ratio; diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 7592a8a2418e..878c872a65fe 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -42,7 +42,7 @@ class ReplayFuncNode : public SearchStrategyNode { /*! \brief The number of trials per iteration. */ int num_trials_per_iter; /*! \brief The number of total trials. */ - int num_trials_total; + int max_trials_per_task; /*! \brief The module to be tuned. */ IRModule mod_{nullptr}; @@ -59,7 +59,7 @@ class ReplayFuncNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("num_trials_total", &num_trials_total); + v->Visit("max_trials_per_task", &max_trials_per_task); // `space_generator_` is not visited // `mod_` is not visited // `args_info_` is not visited @@ -104,10 +104,10 @@ class ReplayFuncNode : public SearchStrategyNode { }; inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { - if (st >= self->num_trials_total) { + if (st >= self->max_trials_per_task) { return NullOpt; } - ed = std::min(ed, self->num_trials_total); + ed = std::min(ed, self->max_trials_per_task); Array result; for (int i = st; i < ed; i++) { for (;;) { @@ -136,10 +136,10 @@ inline void ReplayFuncNode::State::NotifyRunnerResults(const Array ed += self->num_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int max_trials_per_task) { ObjectPtr n = make_object(); n->num_trials_per_iter = num_trials_per_iter; - n->num_trials_total = num_trials_total; + n->max_trials_per_task = max_trials_per_task; return SearchStrategy(n); } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1eac10d1ad82..f17c5d6c4eb3 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -45,7 +45,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The number of trials per iteration. */ int num_trials_per_iter; /*! \brief The number of total trials. */ - int num_trials_total; + int max_trials_per_task; /*! \brief The module to be tuned. */ Array per_thread_mod_{nullptr}; @@ -62,7 +62,7 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("num_trials_total", &num_trials_total); + v->Visit("max_trials_per_task", &max_trials_per_task); // `per_thread_mod_` is not visited // `args_info_` is not visited // `postprocs_` is not visited @@ -119,10 +119,10 @@ class ReplayTraceNode : public SearchStrategyNode { }; inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { - if (st >= self->num_trials_total) { + if (st >= self->max_trials_per_task) { return NullOpt; } - ed = std::min(ed, self->num_trials_total); + ed = std::min(ed, self->max_trials_per_task); ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); Array per_task_result(ed - st, MeasureCandidate{nullptr}); @@ -150,10 +150,10 @@ inline void ReplayTraceNode::State::NotifyRunnerResults(const Arraynum_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_trials_total) { +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task) { ObjectPtr n = make_object(); n->num_trials_per_iter = num_trials_per_iter; - n->num_trials_total = num_trials_total; + n->max_trials_per_task = max_trials_per_task; return SearchStrategy(n); } diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc new file mode 100644 index 000000000000..1bcebcdcc794 --- /dev/null +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +struct TaskRecord { + TuneContext task; + double weight; + double flop; + std::vector best_time_cost_history; // in ms + int trials; +}; + +/*! \brief The gradient based task scheduler. */ +class GradientBasedNode final : public TaskSchedulerNode { + public: + // Parameters used in gradient computation + double alpha; + int window_size; + + std::vector task_records_; + std::vector best_time_cost_per_task_; // in ms + int num_rounds_already_; + support::LinearCongruentialEngine::TRandState rand_state_; + + void VisitAttrs(tvm::AttrVisitor* v) { + TaskSchedulerNode::VisitAttrs(v); + v->Visit("alpha", &alpha); + v->Visit("window_size", &window_size); + // `task_records_` is not visited. + // `best_time_cost_per_task_` is not visited. + // `num_rounds_already_` is not visited. + // `rand_state_` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.GradientBased"; + TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); + + public: + std::string TuningStatistics() const { + std::ostringstream os; + int n_tasks = task_records_.size(); + int total_trials = 0; + double total_latency = 0.0; + support::TablePrinter p; + p.Row() << "ID" + << "Name" + << "FLOP" + << "Weight" + << "Speed (GFLOPS)" + << "Latency (us)" + << "Weighted Latency (us)" + << "Trials" + << "Terminated"; + p.Separator(); + for (int i = 0; i < n_tasks; ++i) { + const TaskRecord& record = task_records_[i]; + auto row = p.Row(); + int trials = record.trials; + row << /*id=*/i // + << /*name=*/record.task->task_name.value() // + << /*flops=*/static_cast(record.flop) // + << /*weight=*/static_cast(record.weight); + if (trials == 0) { + row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A"; + } else { + double latency = record.best_time_cost_history.back() * 1000.0; + double speed = record.flop / latency / 1000.0; + double weighted_latency = latency * record.weight; + row << /*speed=*/speed << /*latency=*/latency << /*weighted_latency=*/weighted_latency; + total_latency += weighted_latency; + total_trials += trials; + } + row << trials; + if (tasks[i]->is_terminated) { + row << "Y"; + } else { + row << ""; + } + } + p.Separator(); + os << p.AsStr() // + << "\nTotal trials: " << total_trials // + << "\nTotal latency (us): " << total_latency // + << "\n"; + return os.str(); + } + + int NextTaskId() final { + int n_tasks = task_records_.size(); + // Round robin + if (num_rounds_already_ == 0) { + LOG(INFO) << "\n" << this->TuningStatistics(); + } + if (num_rounds_already_ < n_tasks) { + return num_rounds_already_++; + } + if (num_rounds_already_ == n_tasks) { + for (int i = 0; i < n_tasks; ++i) { + this->JoinRunningTask(i); + } + } + ++num_rounds_already_; + // Check running tasks + std::vector tasks_alive; + tasks_alive.reserve(n_tasks); + for (int i = 0; i < n_tasks; ++i) { + this->TouchTask(i); + if (!tasks[i]->is_terminated) { + tasks_alive.push_back(i); + } + } + if (tasks_alive.empty()) { + return -1; + } + std::vector grad; + grad.reserve(n_tasks); + for (int task_id : tasks_alive) { + const TaskRecord& record = task_records_[task_id]; + const int w = this->window_size; + int n = record.best_time_cost_history.size(); + ICHECK_GE(n, 1); + double best = record.best_time_cost_history[n - 1]; + double g1 = (n >= 1 + w) ? (record.best_time_cost_history[n - 1 - w] - best) / w : 0.0; + double g2 = best / n; + double g = alpha * g1 + (1 - alpha) * g2; + grad.push_back(g * record.weight); + } + auto max_grad = std::max_element(grad.begin(), grad.end()); + auto min_grad = std::min_element(grad.begin(), grad.end()); + int task_id = -1; + if (*max_grad == *min_grad) { + task_id = tasks_alive[tir::SampleInt(&rand_state_, 0, tasks_alive.size())]; + } else { + task_id = tasks_alive[std::distance(grad.begin(), max_grad)]; + } + if (tasks[task_id]->runner_futures.defined()) { + JoinRunningTask(task_id); + } + return task_id; + } + + Array JoinRunningTask(int task_id) final { + TaskRecord& record = task_records_[task_id]; + Array results = TaskSchedulerNode::JoinRunningTask(task_id); + double& best_time_cost = this->best_time_cost_per_task_[task_id]; + for (const RunnerResult& result : results) { + if (!result->error_msg.defined()) { + best_time_cost = std::min(best_time_cost, GetRunMsMedian(result)); + } + } + record.best_time_cost_history.push_back(best_time_cost); + record.trials += results.size(); + LOG(INFO) << "[Updated] Task #" << task_id << ": " << record.task->task_name << "\n" + << this->TuningStatistics(); + return results; + } +}; + +TaskScheduler TaskScheduler::GradientBased(Array tasks, // + Array task_weights, // + Builder builder, // + Runner runner, // + Database database, // + int max_trials, // + Optional cost_model, // + Optional> measure_callbacks, + double alpha, int window_size, + support::LinearCongruentialEngine::TRandState seed) { + CHECK_EQ(tasks.size(), task_weights.size()) + << "The size of `tasks` should have the same as `task_weights`."; + int n_tasks = tasks.size(); + std::vector task_records; + task_records.reserve(n_tasks); + for (int i = 0; i < n_tasks; ++i) { + task_records.push_back(TaskRecord{ + /*task=*/tasks[i], + /*weights=*/task_weights[i]->value, + /*flop=*/std::max(1.0, tir::EstimateTIRFlops(tasks[i]->mod.value())), + /*best_time_cost_history=*/{}, + /*trials=*/0, + }); + } + ObjectPtr n = make_object(); + n->tasks = tasks; + n->builder = builder; + n->runner = runner; + n->database = database; + n->max_trials = max_trials; + n->cost_model = cost_model; + n->measure_callbacks = measure_callbacks.value_or({}); + n->num_trials_already = 0; + n->alpha = alpha; + n->window_size = window_size; + n->task_records_ = std::move(task_records); + n->best_time_cost_per_task_ = std::vector(n_tasks, 1e100); + n->num_rounds_already_ = 0; + support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + for (const TuneContext& task : tasks) { + task->task_scheduler = n.get(); + } + return TaskScheduler(n); +} + +TVM_REGISTER_NODE_TYPE(GradientBasedNode); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") + .set_body_typed(TaskScheduler::GradientBased); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 72989a20bcd5..a5731af1fc4d 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -38,11 +38,14 @@ class RoundRobinNode final : public TaskSchedulerNode { protected: int NextTaskId() final { int n_tasks = this->tasks.size(); + for (int i = 0; i < n_tasks; ++i) { + this->TouchTask(i); + } for (int i = 0; i < n_tasks; ++i) { task_id = (task_id + 1) % n_tasks; TuneContext task = tasks[task_id]; - if (!task->is_stopped) { - if (IsTaskRunning(task_id)) { + if (!task->is_terminated) { + if (task->runner_futures.defined()) { JoinRunningTask(task_id); } return task_id; @@ -56,6 +59,7 @@ TaskScheduler TaskScheduler::RoundRobin(Array tasks, // Builder builder, // Runner runner, // Database database, // + int max_trials, // Optional cost_model, // Optional> measure_callbacks) { ObjectPtr n = make_object(); @@ -63,8 +67,10 @@ TaskScheduler TaskScheduler::RoundRobin(Array tasks, // n->builder = builder; n->runner = runner; n->database = database; + n->max_trials = max_trials; n->cost_model = cost_model; n->measure_callbacks = measure_callbacks.value_or({}); + n->num_trials_already = 0; n->task_id = -1; for (const TuneContext& task : tasks) { task->task_scheduler = n.get(); diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index fdce470fd0ca..e30295fd1a0f 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -26,10 +26,9 @@ namespace meta_schedule { * \param builder The builder to send the candidates to. * \param context The tuning context. * \param candidates The measure candidates. - * \return An array of the builder results. */ -Array SendToBuilder(const Builder& builder, const TuneContext& context, - const Array& candidates) { +void SendToBuilder(const Builder& builder, const TuneContext& context) { + Array candidates = context->measure_candidates.value(); LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder"; Target target = context->target.value(); Array inputs; @@ -37,7 +36,7 @@ Array SendToBuilder(const Builder& builder, const TuneContext& co for (const MeasureCandidate& candidate : candidates) { inputs.push_back(BuilderInput(candidate->sch->mod(), target)); } - return builder->Build(inputs); + context->builder_results = builder->Build(inputs); } /*! @@ -48,9 +47,9 @@ Array SendToBuilder(const Builder& builder, const TuneContext& co * \param builder_results The builder results. * \return An array of the runner results. */ -Array SendToRunner(const Runner& runner, const TuneContext& context, - const Array& candidates, - const Array& builder_results) { +void SendToRunner(const Runner& runner, const TuneContext& context) { + Array candidates = context->measure_candidates.value(); + Array builder_results = context->builder_results.value(); LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner"; Target target = context->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); @@ -71,7 +70,8 @@ Array SendToRunner(const Runner& runner, const TuneContext& contex } Array futures = runner->Run(inputs); if (n_build_errors == 0) { - return futures; + context->runner_futures = futures; + return; } Array results; results.reserve(n); @@ -88,96 +88,90 @@ Array SendToRunner(const Runner& runner, const TuneContext& contex results.push_back(futures[j++]); } } - return results; + context->runner_futures = results; } void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; - LOG(INFO) << "Initializing Task #" << task_id << ": " << task->task_name << ", mod =\n" - << tir::AsTVMScript(task->mod); - this->tasks[task_id]->Initialize(); + LOG(INFO) << "Initializing Task #" << task_id << ": " << task->task_name; + CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(task->space_generator.defined()) + << "ValueError: Require `context.space_generator`, but it is not defined"; + CHECK(task->search_strategy.defined()) + << "ValueError: Require `context.search_strategy`, but it is not defined"; + LOG(INFO) << "\n" << tir::AsTVMScript(task->mod); + task->Initialize(); + Array design_spaces = + task->space_generator.value()->GenerateDesignSpace(task->mod.value()); + LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated"; + for (int i = 0, n = design_spaces.size(); i < n; ++i) { + tir::Schedule sch = design_spaces[i]; + tir::Trace trace = sch->trace().value(); + trace = trace->Simplified(true); + LOG(INFO) << "Design space #" << i << ":\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(trace->AsPython(false), "\n"); + } + task->search_strategy.value()->PreTuning(design_spaces); } void TaskSchedulerNode::Tune() { - for (int i = 0; i < static_cast(this->tasks.size()); i++) { - TuneContext task = tasks[i]; - // Check Optional value validity. - CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(task->space_generator.defined()) - << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(task->search_strategy.defined()) - << "ValueError: Require `context.search_strategy`, but it is not defined"; - InitializeTask(i); - Array design_spaces = - task->space_generator.value()->GenerateDesignSpace(task->mod.value()); - LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated"; - for (int i = 0, n = design_spaces.size(); i < n; ++i) { - tir::Schedule sch = design_spaces[i]; - tir::Trace trace = sch->trace().value(); - trace = trace->Simplified(true); - LOG(INFO) << "Design space #" << i << ":\n" - << tir::AsTVMScript(sch->mod()) << "\n" - << Concat(trace->AsPython(false), "\n"); - } - task->search_strategy.value()->PreTuning(design_spaces); + int n_tasks = this->tasks.size(); + for (int task_id = 0; task_id < n_tasks; ++task_id) { + InitializeTask(task_id); } - int running_tasks = tasks.size(); - for (int task_id; (task_id = NextTaskId()) != -1;) { + for (int task_id; num_trials_already < max_trials && (task_id = NextTaskId()) != -1;) { LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; TuneContext task = tasks[task_id]; - ICHECK(!task->is_stopped); + ICHECK(!task->is_terminated); ICHECK(!task->runner_futures.defined()); SearchStrategy strategy = task->search_strategy.value(); if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { - Array builder_results = - SendToBuilder(this->builder, task, task->measure_candidates.value()); - task->builder_results = builder_results; - task->runner_futures = - SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); + num_trials_already += task->measure_candidates.value().size(); + SendToBuilder(this->builder, task); + SendToRunner(this->runner, task); } else { - SetTaskStopped(task_id); + ICHECK(!task->is_terminated); + task->is_terminated = true; --running_tasks; LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; } } - ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished"; - int n_tasks = this->tasks.size(); for (int task_id = 0; task_id < n_tasks; ++task_id) { - ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running"; TuneContext task = tasks[task_id]; + if (!task->is_terminated) { + if (task->runner_futures.defined()) { + JoinRunningTask(task_id); + } + task->is_terminated = true; + --running_tasks; + LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; + } task->search_strategy.value()->PostTuning(); } } -void TaskSchedulerNode::SetTaskStopped(int task_id) { +void TaskSchedulerNode::TouchTask(int task_id) { TuneContext task = tasks[task_id]; - ICHECK(!task->is_stopped); - task->is_stopped = true; -} - -bool TaskSchedulerNode::IsTaskRunning(int task_id) { - TuneContext task = tasks[task_id]; - if (task->is_stopped || !task->runner_futures.defined()) { - return false; - } - for (const RunnerFuture future : task->runner_futures.value()) { - if (!future->Done()) { - return true; + if (!task->is_terminated && task->runner_futures.defined()) { + for (const RunnerFuture future : task->runner_futures.value()) { + if (!future->Done()) { + return; + } } + this->JoinRunningTask(task_id); } - this->JoinRunningTask(task_id); - return false; } -void TaskSchedulerNode::JoinRunningTask(int task_id) { +Array TaskSchedulerNode::JoinRunningTask(int task_id) { TuneContext task = tasks[task_id]; ICHECK(task->runner_futures.defined()); Array futures = task->runner_futures.value(); int n = futures.size(); Array results; results.reserve(n); - for (const RunnerFuture future : task->runner_futures.value()) { + for (RunnerFuture future : futures) { results.push_back(future->Result()); } task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(), @@ -194,6 +188,7 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { task->measure_candidates = NullOpt; task->builder_results = NullOpt; task->runner_futures = NullOpt; + return results; } TaskScheduler TaskScheduler::PyTaskScheduler( @@ -201,12 +196,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Builder builder, // Runner runner, // Database database, // + int max_trials, // Optional cost_model, // Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // - PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // - PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // + PyTaskSchedulerNode::FTouchTask f_touch_task, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // PyTaskSchedulerNode::FNextTaskId f_next_task_id) { ObjectPtr n = make_object(); @@ -214,16 +209,17 @@ TaskScheduler TaskScheduler::PyTaskScheduler( n->builder = builder; n->runner = runner; n->database = database; + n->max_trials = max_trials; n->cost_model = cost_model; if (measure_callbacks.defined()) { n->measure_callbacks = measure_callbacks.value(); } else { n->measure_callbacks = {}; } + n->num_trials_already = 0; n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; - n->f_set_task_stopped = f_set_task_stopped; - n->f_is_task_running = f_is_task_running; + n->f_touch_task = f_touch_task; n->f_join_running_task = f_join_running_task; n->f_next_task_id = f_next_task_id; return TaskScheduler(n); @@ -237,10 +233,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") .set_body_method(&TaskSchedulerNode::Tune); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerInitializeTask") .set_body_method(&TaskSchedulerNode::InitializeTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped") - .set_body_method(&TaskSchedulerNode::SetTaskStopped); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning") - .set_body_method(&TaskSchedulerNode::IsTaskRunning); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") + .set_body_method(&TaskSchedulerNode::TouchTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 31a913e80798..ba8ee58c5ba4 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -42,12 +42,9 @@ TuneContext::TuneContext(Optional mod, n->postprocs = postprocs.value_or({}); n->mutator_probs = mutator_probs.value_or({}); n->task_name = task_name; - if (rand_state == -1) { - rand_state = support::LinearCongruentialEngine::DeviceRandom(); - } support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; - n->is_stopped = false; + n->is_terminated = false; n->runner_futures = NullOpt; n->measure_candidates = NullOpt; data_ = std::move(n); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 90d1e4755cac..2ee18a8668be 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -38,6 +38,7 @@ #include #include +#include #include #include @@ -45,6 +46,7 @@ #include "../support/array.h" #include "../support/base64.h" #include "../support/nd_int_set.h" +#include "../support/table_printer.h" #include "../support/utils.h" #include "../tir/schedule/primitive.h" #include "../tir/schedule/utils.h" @@ -366,6 +368,27 @@ inline int GetTargetNumCores(const Target& target) { return num_cores; } +/*! + * \brief Get the median of the running time from RunnerResult in millisecond + * \param results The results from RunnerResult + * \return The median of the running time in millisecond + */ +inline double GetRunMsMedian(const RunnerResult& runner_result) { + Array run_secs = runner_result->run_secs.value(); + ICHECK(!run_secs.empty()); + std::vector v; + v.reserve(run_secs.size()); + std::transform(run_secs.begin(), run_secs.end(), std::back_inserter(v), + [](const FloatImm& f) -> double { return f->value; }); + std::sort(v.begin(), v.end()); + int n = v.size(); + if (n % 2 == 0) { + return (v[n / 2] + v[n / 2 + 1]) * 0.5 * 1000.0; + } else { + return v[n / 2] * 1000.0; + } +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/support/table_printer.h b/src/support/table_printer.h new file mode 100644 index 000000000000..364e3f4ba6bd --- /dev/null +++ b/src/support/table_printer.h @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SUPPORT_TABLE_PRINTER_H_ +#define TVM_SUPPORT_TABLE_PRINTER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace support { + +/*! + * \brief TablePrinter is a helper class to print a table. + * + * \code + * + * TablePrinter p; + * p.Row() << "ID" + * << "Latency (ms)" + * << "Speed (GFLOPS)" + * << "Trials"; + * p.Separator(); + * p.Row() << 0 << 0.072 << 4208.59 << 6656; + * p.Row() << 1 << 0.020 << 3804.24 << 7296; + * p.Row() << 2 << 0.003 << 1368.10 << 320; + * p.Row() << 3 << 0.010 << 117.75 << 128; + * p.Row() << 4 << 0.002 << 23.75 << 320; + * p.Row() << 5 << 0.004 << 1696.18 << 704; + * p.Row() << 6 << 0.002 << 69.89 << 320; + * p.Row() << 7 << 0.047 << 6394.42 << 4352; + * p.Separator(); + * std::cout << tab.AsStr(); + * + * \endcode + */ +class TablePrinter { + struct Line; + + public: + /*! \brief Create a new row */ + inline Line Row(); + /*! \brief Create a row separator */ + inline void Separator(); + /*! \brief Converts TablePrinter to a string */ + inline std::string AsStr() const; + + private: + std::vector> tab_; + friend struct Line; + + /*! \brief A helper class to print a specific row in the table */ + struct Line { + inline Line& operator<<(int x); + inline Line& operator<<(double x); + inline Line& operator<<(const std::string& x); + + private: + TablePrinter* p; + friend class TablePrinter; + }; +}; + +inline TablePrinter::Line& TablePrinter::Line::operator<<(int x) { + p->tab_.back().push_back(std::to_string(x)); + return *this; +} + +inline TablePrinter::Line& TablePrinter::Line::operator<<(double x) { + std::ostringstream os; + os << std::fixed << std::setprecision(4) << x; + p->tab_.back().push_back(os.str()); + return *this; +} + +inline TablePrinter::Line& TablePrinter::Line::operator<<(const std::string& x) { + p->tab_.back().push_back(x); + return *this; +} + +inline TablePrinter::Line TablePrinter::Row() { + tab_.emplace_back(); + Line line; + line.p = this; + return line; +} + +inline void TablePrinter::Separator() { tab_.emplace_back(); } + +inline std::string TablePrinter::AsStr() const { + constexpr char kRowSep = '-'; + constexpr char kColSep = '|'; + if (tab_.empty()) return ""; + std::vector column_width; + for (const std::vector& row : tab_) { + if (row.size() > column_width.size()) { + column_width.resize(row.size(), 0); + } + for (size_t i = 0; i < row.size(); ++i) { + column_width[i] = std::max(column_width[i], row[i].size()); + } + } + ICHECK(!column_width.empty()); + size_t total_width = + std::accumulate(column_width.begin(), column_width.end(), 0) + 3 * column_width.size() - 1; + bool is_first = true; + std::ostringstream os; + for (const std::vector& row : tab_) { + if (is_first) { + is_first = false; + } else { + os << '\n'; + } + if (row.empty()) { + os << std::string(total_width, kRowSep); + continue; + } + for (size_t i = 0; i < column_width.size(); ++i) { + if (i != 0) { + os << kColSep; + } + std::string s = (i < row.size()) ? row[i] : ""; + os << std::string(column_width[i] + 1 - s.size(), ' ') << s << ' '; + } + } + return os.str(); +} + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_TABLE_PRINTER_H_ diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index e261cf2a03de..59a19631fc09 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -214,9 +214,6 @@ Schedule ConcreteScheduleNode::Copy() { /******** Schedule: Schedule: Sampling ********/ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { - if (seed == -1) { - seed = std::random_device()(); - } support::LinearCongruentialEngine(&rand_state_).Seed(seed); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 59764e36fe70..4534406d79cf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -62,7 +62,7 @@ class ConcreteScheduleNode : public ScheduleNode { ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } Schedule Copy() override; - void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; + void Seed(support::LinearCongruentialEngine::TRandState seed) final; support::LinearCongruentialEngine::TRandState ForkSeed() final; public: diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index 73640bdf74f6..df8d0fe38315 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -16,19 +16,18 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import re -from typing import List from random import random +from typing import List import pytest import tvm from tvm.ir import IRModule, assert_structural_equal from tvm.meta_schedule.builder import BuilderResult from tvm.meta_schedule.measure_callback import PyMeasureCallback -from tvm.meta_schedule.builder import BuilderResult from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.testing import DummyDatabase, DummyRunner, DummyBuilder from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler +from tvm.meta_schedule.testing import DummyBuilder, DummyDatabase, DummyRunner from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.tir.schedule import Schedule @@ -79,7 +78,7 @@ def apply( measure_callback = FancyMeasureCallback() measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase()), + RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], @@ -103,7 +102,7 @@ def apply( measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase()), + RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 663614371eeb..ca9c50b521be 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -83,9 +83,11 @@ def _schedule_matmul(sch: Schedule): @pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 - num_trials_total = 20 + max_trials_per_task = 20 - strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + strategy = TestClass( + num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task + ) context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) context.space_generator.initialize_with_tune_context(context) spaces = context.space_generator.generate_design_space(context.mod) @@ -119,11 +121,11 @@ def _schedule_matmul_small(sch: Schedule): _, _ = sch.split(k, sch.sample_perfect_tile(k, n=2)) num_trials_per_iter = 10 - num_trials_total = 2000 + max_trials_per_task = 2000 strategy = EvolutionarySearch( num_trials_per_iter=num_trials_per_iter, - num_trials_total=num_trials_total, + max_trials_per_task=max_trials_per_task, population_size=5, init_measured_ratio=0.1, init_min_unmeasured=50, @@ -148,6 +150,7 @@ def _schedule_matmul_small(sch: Schedule): database=DummyDatabase(), cost_model=ms.cost_model.RandomModel(), measure_callbacks=[], + max_trials=1, ) context.space_generator.initialize_with_tune_context(context) spaces = context.space_generator.generate_design_space(context.mod) @@ -180,11 +183,11 @@ def _schedule_matmul_empty(sch: Schedule): return sch num_trials_per_iter = 10 - num_trials_total = 100 + max_trials_per_task = 100 strategy = EvolutionarySearch( num_trials_per_iter=num_trials_per_iter, - num_trials_total=num_trials_total, + max_trials_per_task=max_trials_per_task, population_size=5, init_measured_ratio=0.1, init_min_unmeasured=50, @@ -209,6 +212,7 @@ def _schedule_matmul_empty(sch: Schedule): database=DummyDatabase(), cost_model=ms.cost_model.RandomModel(), measure_callbacks=[], + max_trials=1, ) context.space_generator.initialize_with_tune_context(context) spaces = context.space_generator.generate_design_space(context.mod) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index e49c35fa445c..26a2733980c0 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -17,31 +17,33 @@ """ Test Meta Schedule Task Scheduler """ import random -import weakref import sys -from typing import List +import weakref +from typing import Set import pytest import tvm from tvm._ffi.base import TVMError -from tvm.ir import IRModule from tvm.meta_schedule import TuneContext, measure_callback from tvm.meta_schedule.search_strategy import ReplayTrace from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin +from tvm.meta_schedule.task_scheduler import GradientBased, PyTaskScheduler, RoundRobin +from tvm.meta_schedule.testing import DummyBuilder, DummyDatabase, DummyRunner from tvm.meta_schedule.utils import derived_object -from tvm.meta_schedule.testing import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture from tvm.script import tir as T from tvm.tir import Schedule - # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module class MatmulModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + def main( # type: ignore + a: T.handle, + b: T.handle, + c: T.handle, + ) -> None: # pylint: disable=no-self-argument T.func_attr({"global_symbol": "main", "tir.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") @@ -50,14 +52,18 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s with T.block("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): - C[vi, vj] = 0.0 + C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @tvm.script.ir_module class MatmulReluModule: @T.prim_func - def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument + def main( # type: ignore + a: T.handle, + b: T.handle, + d: T.handle, + ) -> None: # pylint: disable=no-self-argument T.func_attr({"global_symbol": "main", "tir.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") @@ -67,18 +73,22 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s with T.block("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): - C[vi, vj] = 0.0 + C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] for i, j in T.grid(1024, 1024): with T.block("relu"): vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = T.max(C[vi, vj], 0.0) + D[vi, vj] = T.max(C[vi, vj], 0.0) # type: ignore @tvm.script.ir_module class BatchMatmulModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + def main( # type: ignore + a: T.handle, + b: T.handle, + c: T.handle, + ) -> None: # pylint: disable=no-self-argument T.func_attr({"global_symbol": "main", "tir.noalias": True}) A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) @@ -87,7 +97,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s with T.block("matmul"): vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) with T.init(): - C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = 0.0 # type: ignore C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @@ -117,37 +127,36 @@ def _schedule_batch_matmul(sch: Schedule): @derived_object class MyTaskScheduler(PyTaskScheduler): - done = set() + done: Set = set() def next_task_id(self) -> int: while len(self.done) != len(self.tasks): x = random.randint(0, len(self.tasks) - 1) task = self.tasks[x] - if not task.is_stopped: + if not task.is_terminated: """Calling base func via following route: Python side: - PyTaskScheduler does not have `_is_task_running` - Call TaskScheduler's `is_task_running`, which calls ffi + PyTaskScheduler does not have `_touch_task` + Call TaskScheduler's `touch_task`, which calls ffi C++ side: - The ffi calls TaskScheduler's `is_task_running` + The ffi calls TaskScheduler's `touch_task` But it is overridden in PyTaskScheduler PyTaskScheduler checks if the function is overridden in python If not, it returns the TaskScheduler's vtable, calling - TaskScheduler::IsTaskRunning + TaskScheduler::TouchTask """ - if self.is_task_running(x): + if task.runner_futures is not None: self.join_running_task(x) return x - else: - self.done.add(x) + self.done.add(x) return -1 def test_meta_schedule_task_scheduler_single(): num_trials_per_iter = 3 - num_trials_total = 10 + max_trials_per_task = 10 sch_fn = ScheduleFn(sch_fn=_schedule_matmul) - replay = ReplayTrace(num_trials_per_iter, num_trials_total) + replay = ReplayTrace(num_trials_per_iter, max_trials_per_task) task = TuneContext( MatmulModule, target=tvm.target.Target("llvm"), @@ -163,20 +172,21 @@ def test_meta_schedule_task_scheduler_single(): DummyRunner(), database, measure_callbacks=[measure_callback.AddToDatabase()], + max_trials=max_trials_per_task, ) round_robin.tune() - assert len(database) == num_trials_total + assert len(database) == max_trials_per_task def test_meta_schedule_task_scheduler_multiple(): num_trials_per_iter = 6 - num_trials_total = 101 + max_trials_per_task = 101 tasks = [ TuneContext( MatmulModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="Matmul", rand_state=42, ), @@ -184,7 +194,7 @@ def test_meta_schedule_task_scheduler_multiple(): MatmulReluModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), @@ -192,7 +202,7 @@ def test_meta_schedule_task_scheduler_multiple(): BatchMatmulModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="BatchMatmul", rand_state=0x114514, ), @@ -204,9 +214,10 @@ def test_meta_schedule_task_scheduler_multiple(): DummyRunner(), database, measure_callbacks=[measure_callback.AddToDatabase()], + max_trials=max_trials_per_task * len(tasks), ) round_robin.tune() - assert len(database) == num_trials_total * len(tasks) + assert len(database) == max_trials_per_task * len(tasks) for task in tasks: assert ( len( @@ -215,7 +226,7 @@ def test_meta_schedule_task_scheduler_multiple(): 100000, ) ) - == num_trials_total + == max_trials_per_task ) @@ -225,7 +236,7 @@ class NIETaskScheduler(PyTaskScheduler): pass with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"): - scheduler = NIETaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) + scheduler = NIETaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase(), 1) scheduler.next_task_id() @@ -240,6 +251,7 @@ def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid measure_callbacks=[ measure_callback.AddToDatabase(), ], + max_trials=10, ) test = weakref.ref(scheduler) # test if it can be destructed successfully del scheduler @@ -249,13 +261,13 @@ def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name num_trials_per_iter = 6 - num_trials_total = 101 + max_trials_per_task = 101 tasks = [ TuneContext( MatmulModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="Matmul", rand_state=42, ), @@ -263,7 +275,7 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d MatmulReluModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), @@ -271,7 +283,7 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d BatchMatmulModule, target=tvm.target.Target("llvm"), space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), task_name="BatchMatmul", rand_state=0x114514, ), @@ -285,9 +297,10 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d measure_callbacks=[ measure_callback.AddToDatabase(), ], + max_trials=max_trials_per_task * len(tasks), ) scheduler.tune() - assert len(database) == num_trials_total * len(tasks) + assert len(database) == max_trials_per_task * len(tasks) for task in tasks: assert ( len( @@ -296,7 +309,56 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d 100000, ) ) - == num_trials_total + == max_trials_per_task + ) + + +def test_meta_schedule_task_scheduler_multiple_gradient_based(): + num_trials_per_iter = 6 + max_trials_per_task = 101 + tasks = [ + TuneContext( + MatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + task_name="Matmul", + rand_state=42, + ), + TuneContext( + MatmulReluModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + task_name="MatmulRelu", + rand_state=0xDEADBEEF, + ), + TuneContext( + BatchMatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + task_name="BatchMatmul", + rand_state=0x114514, + ), + ] + database = DummyDatabase() + gradient_based = GradientBased( + tasks, + task_weights=[1.0, 1.0, 1.0], + builder=DummyBuilder(), + runner=DummyRunner(), + database=database, + measure_callbacks=[measure_callback.AddToDatabase()], + seed=0x20220214, + max_trials=max_trials_per_task * len(tasks), + ) + gradient_based.tune() + assert len(database) == max_trials_per_task * len(tasks) + for task in tasks: + assert ( + len(database.get_top_k(database.commit_workload(task.mod), 10000)) + == max_trials_per_task ) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index c6b08500fbe2..389f6c3719aa 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -17,26 +17,31 @@ # pylint: disable=missing-docstring import logging import tempfile -from typing import List from os import path as osp +from typing import List + import numpy as np import pytest import tvm from tvm import relay, tir +from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir import IRModule -from tvm.tir.schedule import BlockRV, Schedule -from tvm.tir.schedule.trace import Trace from tvm.meta_schedule import ReplayTraceConfig -from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload, JSONDatabase +from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload from tvm.meta_schedule.integration import ApplyHistoryBest from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.meta_schedule.tune import tune_relay, tune_extracted_tasks, extract_task_from_relay, Parse +from tvm.meta_schedule.tune import ( + Parse, + extract_task_from_relay, + tune_extracted_tasks, + tune_relay, +) from tvm.meta_schedule.utils import derived_object -from tvm.target.target import Target from tvm.script import tir as T -from tvm._ffi import register_func -import tempfile +from tvm.target.target import Target +from tvm.tir.schedule import BlockRV, Schedule +from tvm.tir.schedule.trace import Trace logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -143,7 +148,7 @@ def test_meta_schedule_tune_relay( target=target, config=ReplayTraceConfig( num_trials_per_iter=32, - num_trials_total=32, + max_trials_per_task=32, ), work_dir=work_dir, database=JSONDatabase( diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py index a07bf1760346..e0a7a8190419 100644 --- a/tests/python/unittest/test_meta_schedule_tune_te.py +++ b/tests/python/unittest/test_meta_schedule_tune_te.py @@ -24,7 +24,6 @@ from tvm.target.target import Target from tvm.tir import Schedule - logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -37,7 +36,7 @@ def test_tune_matmul(): target=Target("llvm --num-cores=16"), config=ReplayTraceConfig( num_trials_per_iter=32, - num_trials_total=32, + max_trials_per_task=32, ), work_dir=work_dir, ) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index efa1183814c8..6a80d895dfdc 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -18,17 +18,16 @@ import logging import tempfile -import tvm import pytest -from tvm.meta_schedule import ReplayTraceConfig, tune_tir -from tvm.meta_schedule.tune_context import TuneContext -from tvm.meta_schedule import schedule_rule, postproc +import tvm +from tvm.meta_schedule import ReplayTraceConfig, postproc, schedule_rule, tune_tir from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T from tvm.target.target import Target from tvm.te.operation import create_prim_func from tvm.tir import Schedule -from tvm.meta_schedule.testing import te_workload logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) @@ -61,7 +60,7 @@ def test_tune_matmul_cpu(): target=Target("llvm --num-cores=16"), config=ReplayTraceConfig( num_trials_per_iter=32, - num_trials_total=32, + max_trials_per_task=32, ), work_dir=work_dir, ) @@ -80,7 +79,7 @@ def test_tune_matmul_cuda(): target=Target("nvidia/geforce-rtx-3070"), config=ReplayTraceConfig( num_trials_per_iter=32, - num_trials_total=32, + max_trials_per_task=32, ), work_dir=work_dir, ) @@ -98,14 +97,14 @@ def test_tune_matmul_cuda_tensor_core(): target = Target("nvidia/geforce-rtx-3070") config = ReplayTraceConfig( num_trials_per_iter=32, - num_trials_total=320, + max_trials_per_task=320, ) class DefaultTensorCore: @staticmethod def _sch_rules(): - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - schedule_rule as M, + from tvm.meta_schedule import ( + schedule_rule as M, # pylint: disable=import-outside-toplevel ) return [ @@ -154,8 +153,8 @@ def _sch_rules(): @staticmethod def _postproc(): - from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel - postproc as M, + from tvm.meta_schedule import ( + postproc as M, # pylint: disable=import-outside-toplevel ) return [ @@ -183,8 +182,8 @@ def _postproc(): print(sch.mod.script()) print(sch.trace) - from tvm.contrib import nvcc import numpy as np + from tvm.contrib import nvcc ctx = tvm.gpu(0) if nvcc.have_tensorcore(ctx.compute_version):