Skip to content

Commit

Permalink
fixed merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 10, 2022
1 parent 5c0d9af commit 3b97529
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0;
class ScheduleBuilder : ExprVisitor {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target),

create_schedule_(create_schedule) {
: target_(target), create_schedule_(create_schedule) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
Expand All @@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor {
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
VisitExpr(relay_func->body);

Expand Down Expand Up @@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor {
}
}
if (use_meta_schedule_) {
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Expand All @@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor {
}
}

return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
lower_te_compute.constant_tensors_);
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
}

void VisitExpr_(const CallNode* call_node) final {
Expand Down

0 comments on commit 3b97529

Please sign in to comment.