diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index a1de51de728d0..2c2042859ddb4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0; class ScheduleBuilder : ExprVisitor { public: explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - - create_schedule_(create_schedule) { + : target_(target), create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); use_meta_schedule_ = backend::IsMetaScheduleEnabled(); @@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor { CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); + Array fn_inputs = lower_te_compute.fn_inputs_; std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); @@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor { } } if (use_meta_schedule_) { - prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); + prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, @@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor { } } - return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule, - prim_func, {}, IRModule(Map({})), - lower_te_compute.constant_tensors_); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); } void VisitExpr_(const CallNode* call_node) final {