From 2dec2dd9a7836e142effa9af2f4ff7a3bb3b0d44 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Mon, 28 Dec 2020 09:54:05 +0800
Subject: [PATCH] [AutoScheduler] Update layout rewrite option setting for
 measuring (#7156)

* Add layout rewrite options for measure

* Update schedule for inserted transform stage

* Set layout rewrite when tuning for network

* Update the log version
---
 include/tvm/auto_scheduler/measure_record.h |  2 +-
 include/tvm/auto_scheduler/search_task.h    |  6 +++-
 python/tvm/auto_scheduler/compute_dag.py    | 36 ++++++++++++++++++-
 python/tvm/auto_scheduler/measure.py        |  4 +--
 .../tvm/auto_scheduler/relay_integration.py | 16 ++++-----
 python/tvm/auto_scheduler/search_task.py    | 27 ++++++++++----
 src/auto_scheduler/compute_dag.cc           | 20 ++++++++---
 src/auto_scheduler/feature.cc               |  6 ++--
 src/auto_scheduler/measure_record.cc        | 12 ++++++-
 src/auto_scheduler/search_task.cc           | 10 ++++--
 10 files changed, 110 insertions(+), 29 deletions(-)

diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h
index 4d7952f74b40..ec40611d49b4 100755
--- a/include/tvm/auto_scheduler/measure_record.h
+++ b/include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
 namespace tvm {
 namespace auto_scheduler {
 
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4";  // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5";  // NOLINT(*)
 
 /*! \brief Callback for logging the input and results of measurements to file */
 class RecordToFileNode : public MeasureCallbackNode {
diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h
index 60e721bd4389..9e7d3aa2cd32 100755
--- a/include/tvm/auto_scheduler/search_task.h
+++ b/include/tvm/auto_scheduler/search_task.h
@@ -118,6 +118,8 @@ class SearchTaskNode : public Object {
   Target target_host;
   /*! \brief Hardware parameters used in this search task. */
   HardwareParams hardware_params;
+  /*! \brief The layout rewrite option used for measuring programs. */
+  LayoutRewriteOption layout_rewrite_option;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("compute_dag", &compute_dag);
@@ -125,6 +127,7 @@
     v->Visit("target", &target);
     v->Visit("target_host", &target_host);
     v->Visit("hardware_params", &hardware_params);
+    v->Visit("layout_rewrite_option", &layout_rewrite_option);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -144,9 +147,10 @@
    * \param target The target device of this search task.
    * \param target_host The target host device of this search task.
    * \param hardware_params Hardware parameters used in this search task.
+   * \param layout_rewrite_option The layout rewrite option used for measuring programs.
    */
   SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
-             Optional<HardwareParams> hardware_params);
+             Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
 
   TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
 };
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d8a242260285..a7f200aa5cdd 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -32,7 +32,12 @@
 
 
 class LayoutRewriteOption:
-    """Options for applying layout rewrite."""
+    """
+    Options for applying layout rewrite.
+
+    NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op,
+    while REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network.
+    """
 
     # Do not perform layout rewrite
     NO_REWRITE = 0
@@ -44,6 +49,35 @@ class LayoutRewriteOption:
     # so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
     REWRITE_FOR_PRE_TRANSFORMED = 2
 
+    @staticmethod
+    def get_target_default(target, in_relay_integration=False):
+        """Get the default layout rewrite option for the specified target.
+        Currently layout rewrite is only enabled for the cpu / mali backends.
+
+        Parameters
+        ----------
+        target: tvm.target.Target
+            The compilation target.
+        in_relay_integration: bool
+            Whether this check is requested by the Relay integration.
+
+        Returns
+        -------
+        layout_rewrite_option: LayoutRewriteOption
+            The default layout rewrite option for the specified target.
+        """
+        layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+        if target.kind.name == "llvm" or (
+            "device" in target.attrs and target.attrs["device"] == "mali"
+        ):
+            layout_rewrite_option = (
+                LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+                if in_relay_integration
+                else LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+            )
+
+        return layout_rewrite_option
+
 
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
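The helper added above is a pure function of its target argument, so the decision table it encodes can be sanity-checked directly. A minimal sketch against this patch (the concrete `llvm` and `cuda` target strings are only examples):

    import tvm
    from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

    # CPU target, standalone-op tuning: insert a layout transform stage.
    assert LayoutRewriteOption.get_target_default(
        tvm.target.Target("llvm")) == LayoutRewriteOption.INSERT_TRANSFORM_STAGE

    # Same target inside the Relay integration: assume pre-transformed inputs.
    assert LayoutRewriteOption.get_target_default(
        tvm.target.Target("llvm"), True) == LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED

    # Targets other than cpu / mali keep layout rewrite disabled.
    assert LayoutRewriteOption.get_target_default(
        tvm.target.Target("cuda")) == LayoutRewriteOption.NO_REWRITE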
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 24a757746ed6..2f177a242835 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -53,7 +53,6 @@
     make_traceback_info,
     request_remote,
 )
-from .compute_dag import LayoutRewriteOption
 from .workload_registry import (
     serialize_workload_registry_entry,
     deserialize_workload_registry_entry,
@@ -211,6 +210,7 @@ def recover_measure_input(inp, rebuild_state=False):
         target=task.target,
         target_host=task.target_host,
         hardware_params=task.hardware_params,
+        layout_rewrite_option=task.layout_rewrite_option,
     )
 
     if rebuild_state:
@@ -576,7 +576,7 @@ def _timed_func(inp_serialized, build_func, verbose):
 
     try:
         sch, args = task.compute_dag.apply_steps_from_state(
-            inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+            inp.state, layout_rewrite=task.layout_rewrite_option
        )
     # pylint: disable=broad-except
     except Exception:
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 2b26fc4931bd..3287f3d4a1e5 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -33,7 +33,7 @@
 from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
 from tvm.tir import expr as _expr
 from . import _ffi_api
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteOption
 from .dispatcher import DispatchContext
 from .search_task import SearchTask
 from .workload_registry import register_workload_tensors
@@ -126,6 +126,9 @@ def extract_tasks(
                 target=target,
                 target_host=target_host,
                 hardware_params=hardware_params,
+                # When the auto-scheduler is used on an end-to-end network, try to apply
+                # layout rewrite to improve the overall performance
+                layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
             )
         )
         weights.append(use_count_dict[ccache_key] + 1)
@@ -259,13 +262,7 @@ def auto_schedule_topi(outs, has_complex_op):
 
     key = register_workload_tensors(dag.hash_key(), io_tensors)
 
-    # only enable layout rewrite for cpu / mali backend
     target = tvm.target.Target.current()
-    enable_layout_rewrite_targets = ["cpu", "mali"]
-    enable_layout_rewrite = any(
-        enable_layout_rewrite_target in target.keys
-        for enable_layout_rewrite_target in enable_layout_rewrite_targets
-    )
 
     env = TracingEnvironment.current
     if env is None:
@@ -284,7 +281,10 @@ def auto_schedule_topi(outs, has_complex_op):
         schedule = te.create_schedule([x.op for x in outs])
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
         # in prepare_layout_rewrite mode
-        if enable_layout_rewrite and has_layout_free:
+        if (
+            LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
+            and has_layout_free
+        ):
             dispatch_ctx = DispatchContext.current
             state = dispatch_ctx.query(target, key, has_complex_op, dag)
             if state is None:
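With this change, `extract_tasks` stamps every extracted task with the network-tuning default (`REWRITE_FOR_PRE_TRANSFORMED` on cpu / mali, `NO_REWRITE` elsewhere) instead of leaving the choice to the measurement side. A sketch of that integration path, with `relay.testing.resnet` standing in for any Relay workload:

    import tvm
    import tvm.relay.testing
    from tvm import relay, auto_scheduler

    mod, params = relay.testing.resnet.get_workload(num_layers=18)
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"], params, tvm.target.Target("llvm")
    )
    for task in tasks:
        # Each task now records how measurement should rewrite layouts.
        print(task.workload_key, task.layout_rewrite_option)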
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index be83e06bb89d..bfa596a1dc61 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -178,6 +178,13 @@ class SearchTask(Object):
         The target host device of this search task.
     hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
+    layout_rewrite_option : Optional[LayoutRewriteOption]
+        The layout rewrite option used for measuring programs. If None, the default value will be
+        set depending on the specified target.
+        auto_scheduler will search for the best schedule under the specified layout rewrite
+        option. NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a
+        standalone op, while REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops
+        inside a network.
 
     Examples
     --------
@@ -204,6 +211,7 @@ class SearchTask(Object):
         target=None,
         target_host=None,
         hardware_params=None,
+        layout_rewrite_option=None,
     ):
         assert (
             func is not None or workload_key is not None
@@ -221,7 +229,13 @@ def __init__(
             target_host = Target(target_host)
 
         self.__init_handle_by_constructor__(
-            _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
+            _ffi_api.SearchTask,
+            compute_dag,
+            workload_key,
+            target,
+            target_host,
+            hardware_params,
+            layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
         )
 
     def tune(self, tuning_options, search_policy=None):
@@ -250,6 +264,7 @@ def apply_best(self, log_file, layout_rewrite_option=None):
         layout_rewrite_option : Optional[LayoutRewriteOption]
             The layout rewrite option.
 
+
         Returns
         -------
             A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
@@ -260,11 +275,9 @@ def apply_best(self, log_file, layout_rewrite_option=None):
                 "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
             )
 
-        if layout_rewrite_option is None:
-            layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
-            if self.target.kind.name == "llvm":
-                layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
-        sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
+        sch, args = self.compute_dag.apply_steps_from_state(
+            inp.state, layout_rewrite_option or self.layout_rewrite_option
+        )
         return sch, args
 
     def print_best(self, log_file, print_mode="schedule"):
@@ -305,6 +318,7 @@ def __getstate__(self):
             "target": self.target,
             "target_host": self.target_host,
             "hardware_params": self.hardware_params,
+            "layout_rewrite_option": self.layout_rewrite_option,
         }
 
     def __setstate__(self, state):
@@ -327,6 +341,7 @@ def __setstate__(self, state):
             state["target"],
             state["target_host"],
             state["hardware_params"],
+            state["layout_rewrite_option"],
         )
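End to end, the new parameter is set once at task construction and then flows through measurement, pickling (`__getstate__` / `__setstate__`), and `apply_best`. A sketch of standalone-op usage, mirroring the docstring example this file already carries (the matmul workload and the log file name are illustrative only):

    import tvm
    from tvm import auto_scheduler, te
    from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

    @auto_scheduler.register_workload
    def matmul(N, M, K):
        A = te.placeholder((N, K), name="A")
        B = te.placeholder((K, M), name="B")
        k = te.reduce_axis((0, K), name="k")
        C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
        return [A, B, C]

    task = auto_scheduler.SearchTask(
        func=matmul,
        args=(128, 128, 128),
        target=tvm.target.Target("llvm"),
        # Optional; defaults to LayoutRewriteOption.get_target_default(target).
        layout_rewrite_option=LayoutRewriteOption.INSERT_TRANSFORM_STAGE,
    )
    # apply_best now falls back to task.layout_rewrite_option when the caller
    # does not pass an explicit option:
    #   sch, args = task.apply_best("matmul.json")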
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 64114c8331b8..b65878225f5a 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -998,11 +998,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
         transform_steps->Set(i, std::move(step));
       }
     }
+
+    // Add a schedule for the newly added transform stage
     Array<Integer> to_fuse;
-    for (size_t i = 0; i < new_shape.size() - 1; i++) {
-      to_fuse.push_back(i);
+
+    if (new_shape.size() >= 5) {
+      to_fuse.push_back(0);
+      to_fuse.push_back(1);
+      to_fuse.push_back(2);
+      transform_steps->push_back(FuseStep(stage_id, to_fuse));
+    } else if (new_shape.size() >= 3) {
+      to_fuse.push_back(0);
+      to_fuse.push_back(1);
+      transform_steps->push_back(FuseStep(stage_id, to_fuse));
     }
-    transform_steps->push_back(FuseStep(stage_id, to_fuse));
     transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
   }
 
@@ -1024,7 +1033,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
       }
       original_compute_op = op;
       CHECK(!new_compute_op.defined());
-      new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+      auto new_attrs = pop->attrs;
+      new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
+      new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
+      new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
     }
   }
 }
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 53287a0eddeb..47b9fb60aab4 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int
         // rebuild task
         Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
         task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
-                          cur_inp->task->target_host, cur_inp->task->hardware_params);
+                          cur_inp->task->target_host, cur_inp->task->hardware_params,
+                          cur_inp->task->layout_rewrite_option);
         task_id = task_cache.size();
 
         // compute min cost for each task
@@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
           // rebuild task for incomplete measure pairs read from file
           Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
           task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                            inputs[i]->task->target_host, inputs[i]->task->hardware_params);
+                            inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                            inputs[i]->task->layout_rewrite_option);
       }
       task_id = task_cache.size();
 
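The rewritten fusion logic above deserves a remark: the old code fused every axis of the inserted transform stage except the innermost before parallelizing, while the new code fuses only the leading two or three axes (for transformed layouts of rank 3-4 and rank >= 5 respectively), presumably to leave the inner axes intact for locality. Restated in plain Python for clarity (the function name is ours, not TVM's):

    def fused_axes_for_transform_stage(rank):
        """Axis indices fused before the kParallel annotation."""
        if rank >= 5:
            return [0, 1, 2]
        if rank >= 3:
            return [0, 1]
        return []  # rank 1-2: no FuseStep is emitted; axis 0 is still parallelized

    assert fused_axes_for_transform_stage(6) == [0, 1, 2]
    assert fused_axes_for_transform_stage(4) == [0, 1]
    assert fused_axes_for_transform_stage(2) == []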
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index faf3fca4cfc4..1120f437b176 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -165,12 +165,16 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
     writer->WriteArrayItem(*data.hardware_params.get());
     if (data.target_host.defined()) {
       writer->WriteArrayItem(data.target_host->str());
+    } else {
+      writer->WriteArrayItem(std::string(""));
     }
+    writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
     writer->EndArray();
   }
   inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
     bool s;
     std::string str_value;
+    int int_value;
     auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
     reader->BeginArray();
     s = reader->NextArrayItem();
@@ -188,7 +192,13 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
     data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
     if (s) {
       reader->Read(&str_value);
-      data->target_host = ::tvm::Target(str_value);
+      if (!str_value.empty()) {
+        data->target_host = ::tvm::Target(str_value);
+      }
+      s = reader->NextArrayItem();
+      ICHECK(s);
+      reader->Read(&int_value);
+      data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
       s = reader->NextArrayItem();
       ICHECK(!s);
     }
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index 93f34609cbbc..0abee16fceab 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
 }
 
 SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
-                       Target target_host, Optional<HardwareParams> hardware_params) {
+                       Target target_host, Optional<HardwareParams> hardware_params,
+                       LayoutRewriteOption layout_rewrite_option) {
   auto node = make_object<SearchTaskNode>();
   node->compute_dag = std::move(compute_dag);
   node->workload_key = std::move(workload_key);
@@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
     node->hardware_params =
         HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
   }
+  node->layout_rewrite_option = layout_rewrite_option;
   data_ = std::move(node);
 }
 
@@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
     .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
-                       Target target_host, Optional<HardwareParams> hardware_params) {
-      return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
+                       Target target_host, Optional<HardwareParams> hardware_params,
+                       int layout_rewrite_option) {
+      return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
+                        LayoutRewriteOption(layout_rewrite_option));
     });
 
 }  // namespace auto_scheduler
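Finally, the log format: the serialized task tuple in a v0.5 record always carries a target_host slot (an empty string when undefined) followed by the integer layout rewrite option, which is what forces the version bump; v0.4 logs will not deserialize against the new reader and should be re-tuned or migrated. Since records are JSON lines, a quick way to check which version a log was written with (the file name is illustrative):

    import json

    with open("tuning.json") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            print(record.get("v"))  # e.g. "v0.5" for logs produced after this patch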