Skip to content

Commit

Permalink
[Relay] Replace compile engine with TE compiler in the VM (apache#8501)
Browse files Browse the repository at this point in the history
* [VM] Add imports to new TE in VM compiler

* [VM] Add comments to compile engine usages

* [VM] Replace depreceated CachedFunc of compile_engine with TE_compiler

* [VM] rm compiler engine compiler.cc

* [VM] Replace compile engine with TECompiler in memory allocator

* [VM] Add relay interface to te_compiler

* [Relay] Fix linting errors

* Move TEcompiler to VMCompilerContext; add global func into IRmodule when lowering in TEcompiler

* add back the check

* skip the check for ext func in tecompiler

* skip tvm::build for external functions

* trigger ci

* retrigger ci

* retrigger ci

* remove the unnecessary loop in tecompiler

Co-authored-by: YuchenJin <yuchenj@cs.washington.edu>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent cd2690d commit d2626a7
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 20 deletions.
5 changes: 5 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ class RelayBuildModule : public runtime::ModuleNode {

auto lowered_funcs = executor_codegen_->GetIRModule();

// No need to build for external functions.
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
lowered_funcs.Set("ext_dev", IRModule());
}

// Generate a placeholder function that attaches linked params as its arguments.
if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";
Expand Down
7 changes: 1 addition & 6 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode {
auto target = Target("ext_dev");
auto global_var = GlobalVar(func_name);
global_var->checked_type_ = key->source_func->checked_type();
ir_module->Add(global_var, key->source_func);
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
return value;
}
Expand Down Expand Up @@ -347,12 +348,6 @@ class LowerTensorExpr : public ExprMutator {
<< ext_func->prim_fn_var->name_hint;

Map<GlobalVar, tir::PrimFunc> prim_fns;

for (auto prim_fn : ext_func->funcs->functions) {
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
}

relay::Function func_with_metadata = func;
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
Expand Down
14 changes: 5 additions & 9 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
#include <vector>

#include "../../../target/source/codegen_source_base.h"
#include "../../backend/compile_engine.h"
#include "../../op/op_common.h"
#include "../../transforms/pass_utils.h"
#include "../utils.h"
Expand Down Expand Up @@ -79,6 +78,7 @@ namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;
using namespace tec;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Expand Down Expand Up @@ -253,7 +253,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
ExprDeviceMap expr_device_map)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
target_host_(target_host),
expr_device_map_(std::move(expr_device_map)) {
Expand Down Expand Up @@ -465,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
// Lower shape function
CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
auto cfunc = context_->compiler->LowerShapeFunc(key);
int op_index = -1;
// pick the only function inside the context
ICHECK_EQ(cfunc->funcs->functions.size(), 1);
Expand Down Expand Up @@ -551,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

CCacheKey key(func, target);
auto mangle_fn = [](String name) { return name; };
auto cfunc = engine_->Lower(key, mangle_fn);
auto cfunc = context_->compiler->Lower(key, mangle_fn);

auto op_index = -1;
if (func->GetAttr<String>(attr::kCompiler).defined()) {
Expand Down Expand Up @@ -857,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
size_t last_register_;
/*! \brief Total number of virtual registers allocated. */
size_t registers_num_;
/*! \brief Compiler engine to lower primitive functions. */
CompileEngine engine_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
/*! \brief Target devices. */
Expand Down Expand Up @@ -1134,8 +1131,8 @@ void VMCompiler::Codegen() {
}
}

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
auto ext_mods = context_.compiler->LowerExternalFunctions();

runtime::Module lib;
if (funcs.size() > 0) {
lib = tvm::build(funcs, target_host_);
Expand All @@ -1146,7 +1143,6 @@ void VMCompiler::Codegen() {
}
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata());
exec_->SetLib(lib);
CompileEngine::Global()->Clear();
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Expand Down
7 changes: 5 additions & 2 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@

#include "../../../runtime/vm/naive_allocator.h"
#include "../../../runtime/vm/profiler/vm.h"
#include "../../backend/compile_engine.h"
#include "../../transforms/pass_utils.h"
#include "../te_compiler.h"
#include "../te_compiler_cache.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -75,12 +76,14 @@ struct VMCompilerContext {
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
// TEcompiler for lowering
tec::TECompiler compiler;
// List of constants
std::vector<NDArray> constants;
// Device type for constants
std::vector<Index> const_device_type;
// List of cached functions
std::vector<CachedFunc> cached_funcs;
std::vector<tec::CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<tir::PrimFunc, size_t, ObjectPtrHash, ObjectPtrEqual> seen_funcs;
};
Expand Down
10 changes: 7 additions & 3 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@
#include <unordered_set>
#include <vector>

#include "../backend/compile_engine.h"
#include "../backend/te_compiler.h"
#include "../backend/te_compiler_cache.h"
#include "../op/memory/memory.h"
#include "../op/vm/vm.h"
#include "./pass_utils.h"
#include "let_list.h"
#include "pattern_utils.h"

using namespace tvm::runtime;
using namespace tvm::relay::tec;

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -271,9 +273,11 @@ class DialectRewriter : public ExprMutator {
Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
const std::vector<Expr>& new_args) {
Array<Expr> shape_func_ins;
auto engine = CompileEngine::Global();

TECompiler compiler;

CCacheKey key(func, target_host_);
auto cfunc = engine->LowerShapeFunc(key);
auto cfunc = compiler->LowerShapeFunc(key);
auto input_states = cfunc->shape_func_param_states;

Array<Integer> is_inputs;
Expand Down

0 comments on commit d2626a7

Please sign in to comment.