From 4cd3a1657c4e2e13abe7281b7cdef5dff73b37ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 18:43:15 +0900 Subject: [PATCH] removed create_schedule stuff --- src/relay/backend/te_compiler_cache.cc | 76 +++++++++++++------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 2c2042859ddb..fc3a3ab335f4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -46,6 +46,7 @@ #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "tvm/relay/op_strategy.h" +#include "tvm/tir/function.h" #include "utils.h" namespace tvm { @@ -115,7 +116,7 @@ Array GetShape(const Array& shape) { return res; } -// Construct a schedule for a given Relay primitive function and target. +// Lowers Relay primitive Function to TE Compute class LowerToTECompute : public backend::MemoizedExprTranslator> { public: explicit LowerToTECompute(Target target) @@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorVisitExpr(relay_func->body); + + Array outputs = this->VisitExpr(relay_func->body); + + candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes + if (candidate_name_.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name_.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name_) << "_"; + candidate_name_ = truncated_name.str(); + } + + return outputs; } Array VisitExpr_(const VarNode* op) final { @@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator fn_inputs_; Array scalars_; std::unordered_map constant_tensors_; - std::ostringstream readable_name_stream_; + std::string candidate_name_; OpImplementation anchor_implementation_; private: tvm::Target target_; + std::ostringstream readable_name_stream_; // Index of the global constants static int const_index; // Cache device copy op for equivalence checking to reduce registry lookup @@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0; // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : ExprVisitor { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), create_schedule_(create_schedule) { + explicit ScheduleBuilder(Target target) : target_(target) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); } CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; - std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); - constexpr static size_t kMaxFuncNameLength = 80; - // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME - // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); + auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships @@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor { te::Schedule schedule{nullptr}; tir::PrimFunc prim_func{nullptr}; // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr && create_schedule_) { + if (anchor_attrs_.as() == nullptr) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor { schedule = Downcast(obj); } } - if (use_meta_schedule_) { + if (backend::IsMetaScheduleEnabled()) { prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( @@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor { ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - } + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; } } @@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor { Attrs anchor_attrs_; int anchor_op_pattern_{0}; bool use_auto_scheduler_; - bool use_meta_schedule_; - bool create_schedule_; }; /*! @@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { - return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { - return name; - }); + auto tgt = tvm::Target("ext_dev"); + LowerToTECompute lower_te_compute(tgt); + auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); + return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, + outputs, te::Schedule(), tir::PrimFunc(), {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); }); } // namespace tec