diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index abab8cc6e0a02..a1de51de728d0 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -45,6 +45,7 @@
 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
+#include "tvm/relay/op_strategy.h"
 #include "utils.h"

 namespace tvm {
@@ -115,99 +116,24 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 }

 // Construct a schedule for a given Relay primitive function and target.
-class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-        device_copy_op_(Op::Get("device_copy")),
-        create_schedule_(create_schedule) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
-  }
+  explicit LowerToTECompute(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}

-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
-    Array<tvm::te::Tensor> fn_inputs;
+  Array<te::Tensor> Lower(const Function& relay_func,
+                          std::function<std::string(std::string)> renamer) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        fn_inputs.push_back(tensor);
         inputs.push_back(tensor);
+        fn_inputs_.push_back(tensor);
       }
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
-    auto outputs = this->VisitExpr(relay_func->body);
-    auto candidate_name = readable_name_stream_.str();
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
-    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
-    // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
-    prim_fn_var->checked_type_ = relay_func->checked_type();
-
-    // Fusion over tupled results may leave identity relationships
-    // between inputs and outputs, and those should not be scheduled.
-    // Hence schedule only non PlaceholderOp outputs.
-    tvm::Array<te::Tensor> tensor_outs;
-    for (const auto& tensor : outputs) {
-      if (!tensor->op.as<te::PlaceholderOpNode>()) {
-        tensor_outs.push_back(tensor);
-      }
-    }
-
-    te::Schedule schedule{nullptr};
-    tir::PrimFunc prim_func{nullptr};
-    // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
-      if (use_auto_scheduler_) {
-        const auto* fauto_schedule =
-            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
-        ICHECK(fauto_schedule != nullptr)
-            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
-        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
-        if (obj.defined()) {
-          schedule = Downcast<te::Schedule>(obj);
-        }
-      }
-      if (use_meta_schedule_) {
-        prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
-        Optional<ObjectRef> opt_mod_or_base_func =
-            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-          prim_func = GetRef<tir::PrimFunc>(result);
-        } else {
-          prim_func = tir::PrimFunc(nullptr);
-        }
-      }
-
-      // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
-      if (!schedule.defined() && !prim_func.defined()) {
-        ICHECK(anchor_implementation_.defined());
-        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-      }
-      if (schedule.defined()) {
-        for (const auto& scalar : scalars_) {
-          if (schedule->Contain(scalar)) {
-            schedule[scalar].compute_inline();
-          }
-        }
-      }
-    }
-
-    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
+    return this->VisitExpr(relay_func->body);
   }

   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +180,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   }

   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

@@ -278,28 +203,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);

-    Array<te::Tensor> outputs;
-    OpImplementation impl;
     // TODO(mbs): device_copy cleanup
     ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
+
     LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-    outputs = lowered_out->outputs;
-    impl = lowered_out->implementation;
+    Array<te::Tensor> outputs = lowered_out->outputs;
+    anchor_implementation_ = lowered_out->implementation;

-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-        anchor_implementation_ = impl;
-      }
-    }
     if (outputs.size() != 1) {
       const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
       ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +218,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       ICHECK_EQ(tuple_type->fields.size(), outputs.size());
     }

-    // TODO(mbs): device_copy cleanup
-    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
     readable_name_stream_ << '_' << op->name;
     return outputs;
   }
@@ -347,27 +255,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     return {tuple[op->index]};
   }

+ public:
+  // Additional outputs
+  Array<te::Tensor> fn_inputs_;
+  Array<te::Operation> scalars_;
+  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
+  std::ostringstream readable_name_stream_;
+  OpImplementation anchor_implementation_;
+
+ private:
+  tvm::Target target_;
+  // Index of the global constants
+  static int const_index;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+int LowerToTECompute::const_index = 0;
+
+// Construct a schedule for a given Relay primitive function and target.
+class ScheduleBuilder : ExprVisitor {
+ public:
+  explicit ScheduleBuilder(Target target, bool create_schedule = true)
+      : target_(target), create_schedule_(create_schedule) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
+  }
+
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+    LowerToTECompute lower_te_compute(target_);
+    Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
+    VisitExpr(relay_func->body);
+
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no other GlobalVar ctors should appear inside the lowering machinery.
+    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    prim_fn_var->checked_type_ = relay_func->checked_type();
+
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+      if (use_meta_schedule_) {
+        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
+                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined() && !prim_func.defined()) {
+        ICHECK(lower_te_compute.anchor_implementation_.defined());
+        schedule =
+            lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      if (schedule.defined()) {
+        for (const auto& scalar : lower_te_compute.scalars_) {
+          if (schedule->Contain(scalar)) {
+            schedule[scalar].compute_inline();
+          }
+        }
+      }
+    }
+
+    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
+                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
+                      lower_te_compute.constant_tensors_);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    if (create_schedule_) {
+      int op_pattern = fpattern[op];
+      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+            << " anchor=" << anchor_op_ << " current=" << op;
+      }
+      if (op_pattern >= anchor_op_pattern_) {
+        anchor_op_ = op;
+        anchor_attrs_ = call_node->attrs;
+        anchor_op_pattern_ = op_pattern;
+      }
+    }
+  }
+
  private:
   tvm::Target target_;
   Op anchor_op_;
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
-  OpImplementation anchor_implementation_;
-  std::ostringstream readable_name_stream_;
-  Array<te::Operation> scalars_;
-  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
   bool use_auto_scheduler_;
   bool use_meta_schedule_;
-  // Cache device copy op for equivalence checking to reduce registry lookup
-  // overhead for each invocation of call node when retrieving schedules.
-  const Op& device_copy_op_;
   bool create_schedule_;
-  // Index of the global constants
-  static int const_index;
 };

-int ScheduleBuilder::const_index = 0;
-
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
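Reviewer note (not part of the patch): a minimal sketch of how the two classes are intended to compose after this split. `LowerForReview` and the identity `renamer` lambda are hypothetical stand-ins for illustration; everything else uses only names introduced by the diff above.

// Hypothetical driver (sketch only), assuming it sits in this translation
// unit alongside the two classes above.
CachedFunc LowerForReview(const Function& relay_func, Target target) {
  // Stand-in renamer; real callers pass a name-deduplicating callback.
  std::function<std::string(std::string)> renamer = [](std::string name) { return name; };

  // Phase 1 on its own (e.g. for task extraction): Relay -> TE compute DAG.
  // Inputs, constants, scalars, and the readable name accumulate on the
  // LowerToTECompute instance as side outputs.
  LowerToTECompute lower_te_compute(target);
  Array<te::Tensor> tensors = lower_te_compute.Lower(relay_func, renamer);
  (void)tensors;

  // Phases 1 + 2 together: Create() runs its own LowerToTECompute, then
  // revisits the Relay body to pick the anchor op and build a schedule
  // (TOPI, auto_scheduler, or meta_schedule), packed into a CachedFunc.
  return ScheduleBuilder(target).Create(relay_func, renamer);
}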