Skip to content

Commit

Permalink
[LLVM][RUNTIME] Make ORCJIT LLVM executor the default one
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 committed Nov 24, 2024
1 parent 4d99ec5 commit 7a09ef6
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
if ((value == "mcjit") || (value == "orcjit")) {
jit_engine_ = value;
} else {
LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`).";
LOG(FATAL) << "invalid jit option " << value << " (can be `orcjit` or `mcjit`).";
}
}

Expand Down Expand Up @@ -530,7 +530,7 @@ std::string LLVMTargetInfo::str() const {
os << quote << Join(",", opts) << quote;
}

if (jit_engine_ != "mcjit") {
if (jit_engine_ != "orcjit") {
os << " -jit=" << jit_engine_;
}

Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class LLVMTargetInfo {
llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
/*!
* \brief Get the LLVM JIT engine type
* \return the type name of the JIT engine (default "mcjit" or "orcjit")
* \return the type name of the JIT engine (default "orcjit" or "mcjit")
*/
const std::string GetJITEngine() const { return jit_engine_; }
/*!
Expand Down Expand Up @@ -348,7 +348,7 @@ class LLVMTargetInfo {
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
std::string jit_engine_ = "mcjit";
std::string jit_engine_ = "orcjit";
};

/*!
Expand Down
16 changes: 13 additions & 3 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Intrinsics.h>
Expand Down Expand Up @@ -512,8 +513,17 @@ void LLVMModuleNode::InitORCJIT() {

#if TVM_LLVM_VERSION >= 130
// linker
const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
const auto linkerBuilder =
[&](llvm::orc::ExecutionSession& session,
const llvm::Triple& triple) -> std::unique_ptr<llvm::orc::ObjectLayer> {
auto GetMemMgr = []() { return std::make_unique<llvm::SectionMemoryManager>(); };
auto ObjLinkingLayer =
std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, std::move(GetMemMgr));
if (triple.isOSBinFormatCOFF()) {
ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true);
}
return ObjLinkingLayer;
};
#endif

Expand Down Expand Up @@ -755,7 +765,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
.set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->SetJITEngine("mcjit");
n->SetJITEngine("orcjit");
n->LoadIR(filename);
return runtime::Module(n);
});
Expand Down
8 changes: 4 additions & 4 deletions tests/python/runtime/test_runtime_module_based_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def verify(data):


@tvm.testing.requires_llvm
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
def test_legacy_compatibility(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
Expand All @@ -70,7 +70,7 @@ def test_legacy_compatibility(target):


@tvm.testing.requires_llvm
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
def test_cpu(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_cpu_get_graph_json():


@tvm.testing.requires_llvm
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
def test_cpu_get_graph_params_run(target):
mod, params = relay.testing.synthetic.get_workload()
with tvm.transform.PassContext(opt_level=3):
Expand Down Expand Up @@ -592,7 +592,7 @@ def verify_rpc_gpu_remove_package_params(obj_format):


@tvm.testing.requires_llvm
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
def test_debug_graph_executor(target):
mod, params = relay.testing.synthetic.get_workload()
with relay.build_config(opt_level=3):
Expand Down
2 changes: 1 addition & 1 deletion tests/python/runtime/test_runtime_module_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


@tvm.testing.requires_llvm
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"])
@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"])
def test_dso_module_load(target):
dtype = "int64"
temp = utils.tempdir()
Expand Down

0 comments on commit 7a09ef6

Please sign in to comment.