From 7a09ef6e9788b4c98a8d469786e45fa479110d43 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Mon, 25 Nov 2024 00:04:14 +0200 Subject: [PATCH] [LLVM][RUNTIME] Make ORCJIT LLVM executor the default one --- src/target/llvm/llvm_instance.cc | 4 ++-- src/target/llvm/llvm_instance.h | 4 ++-- src/target/llvm/llvm_module.cc | 16 +++++++++++++--- .../test_runtime_module_based_interface.py | 8 ++++---- tests/python/runtime/test_runtime_module_load.py | 2 +- 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 0406dcf951bb..e2c5e28592b7 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -269,7 +269,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) if ((value == "mcjit") || (value == "orcjit")) { jit_engine_ = value; } else { - LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`)."; + LOG(FATAL) << "invalid jit option " << value << " (can be `orcjit` or `mcjit`)."; } } @@ -530,7 +530,7 @@ std::string LLVMTargetInfo::str() const { os << quote << Join(",", opts) << quote; } - if (jit_engine_ != "mcjit") { + if (jit_engine_ != "orcjit") { os << " -jit=" << jit_engine_; } diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index add2af6002c6..5cea99403a0b 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -232,7 +232,7 @@ class LLVMTargetInfo { llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } /*! * \brief Get the LLVM JIT engine type - * \return the type name of the JIT engine (default "mcjit" or "orcjit") + * \return the type name of the JIT engine (default "orcjit" or "mcjit") */ const std::string GetJITEngine() const { return jit_engine_; } /*! @@ -348,7 +348,7 @@ class LLVMTargetInfo { llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_; llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; std::shared_ptr target_machine_; - std::string jit_engine_ = "mcjit"; + std::string jit_engine_ = "orcjit"; }; /*! diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 34bbb6a0c6a9..98dbe139f1cd 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -512,8 +513,17 @@ void LLVMModuleNode::InitORCJIT() { #if TVM_LLVM_VERSION >= 130 // linker - const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) { - return std::make_unique(session); + const auto linkerBuilder = + [&](llvm::orc::ExecutionSession& session, + const llvm::Triple& triple) -> std::unique_ptr { + auto GetMemMgr = []() { return std::make_unique(); }; + auto ObjLinkingLayer = + std::make_unique(session, std::move(GetMemMgr)); + if (triple.isOSBinFormatCOFF()) { + ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true); + } + return ObjLinkingLayer; }; #endif @@ -755,7 +765,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { auto n = make_object(); - n->SetJITEngine("mcjit"); + n->SetJITEngine("orcjit"); n->LoadIR(filename); return runtime::Module(n); }); diff --git a/tests/python/runtime/test_runtime_module_based_interface.py b/tests/python/runtime/test_runtime_module_based_interface.py index 3f712587684d..2c46838b942a 100644 --- a/tests/python/runtime/test_runtime_module_based_interface.py +++ b/tests/python/runtime/test_runtime_module_based_interface.py @@ -54,7 +54,7 @@ def verify(data): @tvm.testing.requires_llvm -@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"]) def test_legacy_compatibility(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): @@ -70,7 +70,7 @@ def test_legacy_compatibility(target): @tvm.testing.requires_llvm -@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"]) def test_cpu(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): @@ -113,7 +113,7 @@ def test_cpu_get_graph_json(): @tvm.testing.requires_llvm -@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"]) def test_cpu_get_graph_params_run(target): mod, params = relay.testing.synthetic.get_workload() with tvm.transform.PassContext(opt_level=3): @@ -592,7 +592,7 @@ def verify_rpc_gpu_remove_package_params(obj_format): @tvm.testing.requires_llvm -@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"]) def test_debug_graph_executor(target): mod, params = relay.testing.synthetic.get_workload() with relay.build_config(opt_level=3): diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index 3789a1d0907d..87a8ef9f5e12 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -44,7 +44,7 @@ @tvm.testing.requires_llvm -@pytest.mark.parametrize("target", ["llvm", "llvm -jit=orcjit"]) +@pytest.mark.parametrize("target", ["llvm", "llvm -jit=mcjit"]) def test_dso_module_load(target): dtype = "int64" temp = utils.tempdir()