diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index abab8cc6e0a0..ffcce6e1c8da 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -28,11 +28,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -114,100 +116,40 @@ Array GetShape(const Array& shape) { return res; } -// Construct a schedule for a given Relay primitive function and target. -class ScheduleBuilder : public backend::MemoizedExprTranslator> { +// Lowers Relay primitive Function to TE Compute +class LowerToTECompute : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - device_copy_op_(Op::Get("device_copy")), - create_schedule_(create_schedule) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); - } + explicit LowerToTECompute(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) {} - CachedFunc Create(const Function& relay_func, std::function renamer) { - Array fn_inputs; + Array Lower(const Function& relay_func, + std::function renamer) { for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - fn_inputs.push_back(tensor); inputs.push_back(tensor); + fn_inputs_.push_back(tensor); } memo_[param] = inputs; } readable_name_stream_ << "fused"; - auto outputs = this->VisitExpr(relay_func->body); - auto candidate_name = readable_name_stream_.str(); + + Array outputs = this->VisitExpr(relay_func->body); + + candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { + if (candidate_name_.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and - // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); - prim_fn_var->checked_type_ = relay_func->checked_type(); - - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule{nullptr}; - tir::PrimFunc prim_func{nullptr}; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr && create_schedule_) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - if (use_meta_schedule_) { - prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); - Optional opt_mod_or_base_func = - meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, - Array{IRModule({{prim_fn_var, prim_func}})}); - if (const auto* result = opt_mod_or_base_func.as()) { - prim_func = GetRef(result); - } else { - prim_func = tir::PrimFunc(nullptr); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined() && !prim_func.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - if (schedule.defined()) { - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } + truncated_name << candidate_name_.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name_) << "_"; + candidate_name_ = truncated_name.str(); } - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, - IRModule(Map({})), constant_tensors_); + return outputs; } Array VisitExpr_(const VarNode* op) final { @@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); ICHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - Array outputs; - OpImplementation impl; // TODO(mbs): device_copy cleanup ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; - - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - } + Array outputs = lowered_out->outputs; + op_implementations_[op.operator->()] = lowered_out->implementation; + if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); ICHECK(tuple_type) << "Expected output to be a tuple type " @@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK_EQ(tuple_type->fields.size(), outputs.size()); } - // TODO(mbs): device_copy cleanup - ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; readable_name_stream_ << '_' << op->name; return outputs; } @@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator return {tuple[op->index]}; } + public: + // Additional outputs + Array fn_inputs_; + Array scalars_; + std::unordered_map constant_tensors_; + std::unordered_map op_implementations_; + std::string candidate_name_; + private: tvm::Target target_; - Op anchor_op_; - Attrs anchor_attrs_; - int anchor_op_pattern_{0}; - OpImplementation anchor_implementation_; std::ostringstream readable_name_stream_; - Array scalars_; - std::unordered_map constant_tensors_; - bool use_auto_scheduler_; - bool use_meta_schedule_; + // Index of the global constants + static int const_index; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; - bool create_schedule_; - // Index of the global constants - static int const_index; }; -int ScheduleBuilder::const_index = 0; +int LowerToTECompute::const_index = 0; + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : public ExprVisitor { + public: + explicit ScheduleBuilder(Target target) : target_(target) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + } + + CachedFunc Create(const Function& relay_func, std::function renamer) { + LowerToTECompute lower_te_compute(target_); + Array outputs = lower_te_compute.Lower(relay_func, renamer); + Array fn_inputs = lower_te_compute.fn_inputs_; + VisitExpr(relay_func->body); + + // TODO(mbs): This should be the definitive global by which the PrimFunc is known and + // no other GlobalVar ctors should appear inside the lowering machinery. + auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); + prim_fn_var->checked_type_ = relay_func->checked_type(); + + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule{nullptr}; + tir::PrimFunc prim_func{nullptr}; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + if (backend::IsMetaScheduleEnabled()) { + prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); + Optional opt_mod_or_base_func = + meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, + Array{IRModule({{prim_fn_var, prim_func}})}); + if (const auto* result = opt_mod_or_base_func.as()) { + prim_func = GetRef(result); + } else { + prim_func = tir::PrimFunc(nullptr); + } + } + + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined() && !prim_func.defined()) { + auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->()); + ICHECK(anchor_impl != lower_te_compute.op_implementations_.end()); + schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_); + } + if (schedule.defined()) { + for (const auto& scalar : lower_te_compute.scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + } + + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); + } + + void VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + } + } + + private: + tvm::Target target_; + Op anchor_op_; + Attrs anchor_attrs_; + int anchor_op_pattern_{0}; + bool use_auto_scheduler_; +}; /*! * \brief Create schedule for target. @@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { - return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { - return name; - }); + auto tgt = tvm::Target("ext_dev"); + LowerToTECompute lower_te_compute(tgt); + auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); + return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, + outputs, te::Schedule(), tir::PrimFunc(), {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); }); } // namespace tec