diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 9e88222059cc9..793332f9d565c 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -42,7 +42,11 @@ #include #include #include +#if TVM_LLVM_VERSION < 180 #include +#else +#include +#endif #include #include #include @@ -252,8 +256,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { } } - // Target options + // LLVM JIT engine options + if (const Optional& v = target->GetAttr("jit")) { + String value = v.value(); + if ((value == "mcjit") || (value == "orcjit")) { + jit_engine_ = value; + } else { + LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`)."; + } + } + + // RISCV code model + auto arch = llvm::Triple(triple_).getArch(); + if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) { + code_model_ = llvm::CodeModel::Medium; + } + // Target options #if TVM_LLVM_VERSION < 50 target_options_.LessPreciseFPMADOption = true; #endif @@ -521,6 +540,10 @@ std::string LLVMTargetInfo::str() const { os << quote << Join(",", opts) << quote; } + if (jit_engine_ != "mcjit") { + os << " -jit=" << jit_engine_; + } + return os.str(); } diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 030a7db7210f3..f3948b7a01d29 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -212,6 +212,11 @@ class LLVMTargetInfo { * \return `llvm::FastMathFlags` for this target */ llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } + /*! + * \brief Get the LLVM JIT engine type + * \return the type name of the JIT engine (default "mcjit" or "orcjit") + */ + const std::string GetJITEngine() const { return jit_engine_; } /*! * \brief Get the LLVM optimization level * \return optimization level for this target @@ -324,6 +329,7 @@ class LLVMTargetInfo { llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_; llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; std::shared_ptr target_machine_; + std::string jit_engine_ = "mcjit"; }; /*! diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 4823f2c9eade0..30ce7e4436cfe 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -30,7 +30,10 @@ #include #include #include -#include // Force linking of MCJIT +#include +#include +#include +#include #include #include #include @@ -41,7 +44,11 @@ #include #include #include +#if TVM_LLVM_VERSION < 180 #include +#else +#include +#endif #include #include #include @@ -109,8 +116,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { bool ImplementsFunction(const String& name, bool query_imports) final; + void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; } + private: - void LazyInitJIT(); + void InitMCJIT(); + void InitORCJIT(); bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const; void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const; void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const; @@ -119,8 +129,9 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::unique_ptr llvm_instance_; // JIT lock std::mutex mutex_; - // execution engine - llvm::ExecutionEngine* ee_{nullptr}; + // jit execution engines + llvm::ExecutionEngine* mcjit_ee_{nullptr}; + std::unique_ptr orcjit_ee_{nullptr}; // The raw pointer to the module. llvm::Module* module_{nullptr}; // The unique_ptr owning the module. This becomes empty once JIT has been initialized @@ -128,12 +139,21 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::unique_ptr module_owning_ptr_; /* \brief names of the external functions declared in this module */ Array function_names_; + std::string jit_engine_; }; LLVMModuleNode::~LLVMModuleNode() { - if (ee_ != nullptr) { - ee_->runStaticConstructorsDestructors(true); - delete ee_; + if (mcjit_ee_ != nullptr) { + mcjit_ee_->runStaticConstructorsDestructors(true); + delete mcjit_ee_; + } + if (orcjit_ee_ != nullptr) { + auto dtors = llvm::orc::getDestructors(*module_); + auto dtorRunner = std::make_unique(orcjit_ee_->getMainJITDylib()); + dtorRunner->add(dtors); + auto err = dtorRunner->run(); + ICHECK(!err) << llvm::toString(std::move(err)); + orcjit_ee_.reset(); } module_owning_ptr_.reset(); } @@ -162,7 +182,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr lock(mutex_); @@ -349,6 +371,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { module_owning_ptr_ = cg->Finish(); module_ = module_owning_ptr_.get(); + jit_engine_ = llvm_target->GetJITEngine(); llvm_target->SetTargetMetadata(module_); module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); @@ -381,13 +404,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); } -void LLVMModuleNode::LazyInitJIT() { +void LLVMModuleNode::InitMCJIT() { std::lock_guard lock(mutex_); - if (ee_) { + if (mcjit_ee_) { return; } + // MCJIT builder With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); llvm::EngineBuilder builder(std::move(module_owning_ptr_)); + + // set options builder.setEngineKind(llvm::EngineKind::JIT); #if TVM_LLVM_VERSION <= 170 builder.setOptLevel(llvm::CodeGenOpt::Aggressive); @@ -397,18 +423,31 @@ void LLVMModuleNode::LazyInitJIT() { builder.setMCPU(llvm_target->GetCPU()); builder.setMAttrs(llvm_target->GetTargetFeatures()); builder.setTargetOptions(llvm_target->GetTargetOptions()); + + // create the taget machine auto tm = std::unique_ptr(builder.selectTarget()); if (!IsCompatibleWithHost(tm.get())) { LOG(FATAL) << "Cannot run module, architecture mismatch"; } + + // data layout llvm::DataLayout layout(tm->createDataLayout()); ICHECK(layout == module_->getDataLayout()) << "Data layout mismatch between module(" << module_->getDataLayout().getStringRepresentation() << ")" << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; - ee_ = builder.create(tm.release()); - ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); - ee_->runStaticConstructorsDestructors(false); + + // create MCJIT + mcjit_ee_ = builder.create(tm.release()); + ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for " + << module_->getTargetTriple(); + + VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `" + << llvm_target->GetTargetTriple() << "`" + << " on cpu `" << llvm_target->GetCPU() << "`"; + + // run ctors + mcjit_ee_->runStaticConstructorsDestructors(false); if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { @@ -421,7 +460,104 @@ void LLVMModuleNode::LazyInitJIT() { // lead to a runtime crash. // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize // all loaded objects, which will resolve symbols in JITed code. - ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); + mcjit_ee_->getFunctionAddress( + "__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); +} + +void LLVMModuleNode::InitORCJIT() { + std::lock_guard lock(mutex_); + if (orcjit_ee_) { + return; + } + // ORCJIT builder + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); + llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(llvm_target->GetTargetTriple())); + + // set options + tm_builder.setCPU(llvm_target->GetCPU()); + tm_builder.setFeatures(llvm_target->GetTargetFeatureString()); + tm_builder.setOptions(llvm_target->GetTargetOptions()); +#if TVM_LLVM_VERSION <= 170 + tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive); +#else + tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive); +#endif + + // create the taget machine + std::unique_ptr tm = llvm::cantFail(tm_builder.createTargetMachine()); + if (!IsCompatibleWithHost(tm.get())) { + LOG(FATAL) << "Cannot run module, architecture mismatch"; + } + + // data layout + String module_name = module_->getModuleIdentifier(); + llvm::DataLayout layout(tm->createDataLayout()); + ICHECK(layout == module_->getDataLayout()) + << "Data layout mismatch between module(" + << module_->getDataLayout().getStringRepresentation() << ")" + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; + + // compiler + const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&) + -> llvm::Expected> { + return std::make_unique(std::move(tm)); + }; + +#if TVM_LLVM_VERSION >= 130 + // linker + const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) { + return std::make_unique(session); + }; +#endif + + // create LLJIT + orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder() +#if TVM_LLVM_VERSION >= 110 + .setDataLayout(layout) +#endif + .setCompileFunctionCreator(compilerBuilder) +#if TVM_LLVM_VERSION >= 130 + .setObjectLinkingLayerCreator(linkerBuilder) +#endif + .create()); + + ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for " + << module_->getTargetTriple(); + + // store ctors + auto ctors = llvm::orc::getConstructors(*module_); + llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib()); + ctorRunner.add(ctors); + + // resolve system symbols (like pthread, dl, m, etc.) + auto gen = + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix()); + ICHECK(gen) << llvm::toString(gen.takeError()) << "\n"; + orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get())); + + // transfer module to a clone + auto uctx = std::make_unique(); + auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_))); + + // add the llvm module to run + llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx)); + auto err = orcjit_ee_->addIRModule(std::move(tsm)); + ICHECK(!err) << llvm::toString(std::move(err)); + + VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for triple `" + << llvm_target->GetTargetTriple() << "`" + << " on cpu `" << llvm_target->GetCPU() << "`"; + + // run ctors + err = ctorRunner.run(); + ICHECK(!err) << llvm::toString(std::move(err)); + + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + *ctx_addr = this; + } + runtime::InitContextFunctions( + [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); } bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { @@ -439,20 +575,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const { // first verifies if GV exists. if (module_->getGlobalVariable(name) != nullptr) { - return reinterpret_cast(ee_->getGlobalValueAddress(name)); - } else { - return nullptr; + if (jit_engine_ == "mcjit") { + return reinterpret_cast(mcjit_ee_->getGlobalValueAddress(name)); + } else if (jit_engine_ == "orcjit") { +#if TVM_LLVM_VERSION >= 150 + auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue(); +#else + auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress(); +#endif + return reinterpret_cast(addr); + } else { + LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; + } } + return nullptr; } void* LLVMModuleNode::GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const { // first verifies if GV exists. if (module_->getFunction(name) != nullptr) { - return reinterpret_cast(ee_->getFunctionAddress(name)); - } else { - return nullptr; + if (jit_engine_ == "mcjit") { + return reinterpret_cast(mcjit_ee_->getFunctionAddress(name)); + } else if (jit_engine_ == "orcjit") { +#if TVM_LLVM_VERSION >= 150 + auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue(); +#else + auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress(); +#endif + return reinterpret_cast(addr); + } else { + LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; + } } + return nullptr; } TVM_REGISTER_GLOBAL("target.build.llvm") @@ -473,6 +629,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") module->setTargetTriple(llvm_target->GetTargetTriple()); module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); n->Init(std::move(module), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); return runtime::Module(n); }); @@ -592,6 +749,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { auto n = make_object(); + n->SetJITEngine("mcjit"); n->LoadIR(filename); return runtime::Module(n); }); @@ -613,6 +771,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") std::unique_ptr blob = CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix); n->Init(std::move(blob), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); return runtime::Module(n); }); @@ -642,6 +801,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata auto n = make_object(); n->Init(std::move(mod), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); auto meta_mod = MetadataModuleCreate(metadata); meta_mod->Import(runtime::Module(n)); @@ -688,6 +848,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module auto n = make_object(); n->Init(std::move(mod), std::move(llvm_instance)); + n->SetJITEngine(llvm_target->GetJITEngine()); for (auto m : modules) { n->Import(m); } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index aa4499ec9667d..28c7e066291f4 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") + // LLVM JIT engine mcjit/orcjit + .add_attr_option("jit") .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. diff --git a/tests/python/runtime/test_runtime_module_based_interface.py b/tests/python/runtime/test_runtime_module_based_interface.py index 6e62e3f2155cc..55edbdaccb7d9 100644 --- a/tests/python/runtime/test_runtime_module_based_interface.py +++ b/tests/python/runtime/test_runtime_module_based_interface.py @@ -23,6 +23,7 @@ from tvm.contrib.debugger import debug_executor from tvm.contrib.cuda_graph import cuda_graph_executor import tvm.testing +import pytest def input_shape(mod): @@ -48,10 +49,11 @@ def verify(data): @tvm.testing.requires_llvm -def test_legacy_compatibility(): +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +def test_legacy_compatibility(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + graph, lib, graph_params = relay.build_module.build(mod, target, params=params) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") dev = tvm.cpu() module = graph_executor.create(graph, lib, dev) @@ -63,10 +65,11 @@ def test_legacy_compatibility(): @tvm.testing.requires_llvm -def test_cpu(): +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +def test_cpu(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + complied_graph_lib = relay.build_module.build(mod, target, params=params) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") # raw api dev = tvm.cpu() @@ -105,10 +108,11 @@ def test_cpu_get_graph_json(): @tvm.testing.requires_llvm -def test_cpu_get_graph_params_run(): +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +def test_cpu_get_graph_params_run(target): mod, params = relay.testing.synthetic.get_workload() with tvm.transform.PassContext(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + complied_graph_lib = relay.build_module.build(mod, target, params=params) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") dev = tvm.cpu() from tvm.contrib import utils @@ -584,10 +588,11 @@ def verify_rpc_gpu_remove_package_params(obj_format): @tvm.testing.requires_llvm -def test_debug_graph_executor(): +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +def test_debug_graph_executor(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + complied_graph_lib = relay.build_module.build(mod, target, params=params) data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") # raw api diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index ecaa7067a5a03..3789a1d0907d7 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -22,6 +22,7 @@ import subprocess import tvm.testing from tvm.relay.backend import Runtime +import pytest runtime_py = """ import os @@ -42,9 +43,9 @@ """ -def test_dso_module_load(): - if not tvm.testing.device_enabled("llvm"): - return +@tvm.testing.requires_llvm +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +def test_dso_module_load(target): dtype = "int64" temp = utils.tempdir() @@ -63,7 +64,7 @@ def save_object(names): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") ) - m = tvm.driver.build(mod, target="llvm") + m = tvm.driver.build(mod, target=target) for name in names: m.save(name) @@ -167,6 +168,7 @@ def check_stackvm(device): check_stackvm(device) +@tvm.testing.requires_llvm def test_combine_module_llvm(): """Test combine multiple module into one shared lib.""" # graph @@ -178,9 +180,6 @@ def test_combine_module_llvm(): def check_llvm(): dev = tvm.cpu(0) - if not tvm.testing.device_enabled("llvm"): - print("Skip because llvm is not enabled") - return temp = utils.tempdir() fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1") fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2") diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index d5e8d060254e9..83bd8649700bb 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -171,6 +171,13 @@ def test_target_llvm_options(): ) +def test_target_llvm_jit_options(): + target = tvm.target.Target("llvm -jit=mcjit") + assert target.attrs["jit"] == "mcjit" + target = tvm.target.Target("llvm -jit=orcjit") + assert target.attrs["jit"] == "orcjit" + + def test_target_create(): targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()] for tgt in targets: