diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 9a0d2a2d7f439..57f299be60244 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode { auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index bbca9b4c05fac..d97384bd12953 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -47,8 +47,6 @@ #include "../../../target/source/codegen_source_base.h" #include "../../op/op_common.h" #include "../../transforms/pass_utils.h" -#include "../te_compiler.h" -#include "../te_compiler_cache.h" #include "../utils.h" #include "compiler.h" @@ -466,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function CCacheKey key(func, target_host_); - auto cfunc = compiler_->LowerShapeFunc(key); + auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context ICHECK_EQ(cfunc->funcs->functions.size(), 1); @@ -552,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor { CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; - auto cfunc = compiler_->Lower(key, mangle_fn); + auto cfunc = context_->compiler->Lower(key, mangle_fn); auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { @@ -858,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor { size_t last_register_; /*! \brief Total number of virtual registers allocated. */ size_t registers_num_; - /*! \brief Compiler engine to lower primitive functions. */ - TECompiler compiler_; /*! \brief Global shared meta data */ VMCompilerContext* context_; /*! \brief Target devices. */ @@ -1185,8 +1181,7 @@ void VMCompiler::Codegen() { } } - TECompiler compiler; - auto ext_mods = compiler->LowerExternalFunctions(); + auto ext_mods = context_.compiler->LowerExternalFunctions(); runtime::Module lib; if (funcs.size() > 0) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 096f286d8d92d..a05c52ced07f9 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -76,6 +76,8 @@ struct VMCompilerContext { TagMap tag_map; // Map from global var to a unique integer GlobalMap global_map; + // TEcompiler for lowering + tec::TECompiler compiler; // List of constants std::vector constants; // Device type for constants