diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 7770e42086de..33a09b1ded66 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -230,11 +230,11 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; - auto f = Downcast(kv.second); - cg->AddFunction(f); - } + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { + ICHECK(kv.second->template IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + return Downcast(kv.second); + }); const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index e9eacc27fc72..2f91807b6933 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -731,9 +731,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { } cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); - for (const PrimFunc& f : funcs) { - cg->AddFunction(f); - } + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + if (!linked_params.empty()) { cg->LinkParameters(linked_params); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 52c5b98a0025..a4f007aeebed 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -36,6 +36,7 @@ #include #include +#include #include #include #include @@ -92,6 +93,25 @@ class CodeGenLLVM : public ExprFunctor, * \return the created module. */ virtual std::unique_ptr Finish(); + /*! + * \brief Add functions from the (unordered) range to the current module in a deterministic order. + * The range consists of objects convertible to PrimFunc. + * \param begin The beginning of the range. + * \param end The end of the range. + * \param pfunc Converter function from the range element type to PrimFunc. + */ + template + void AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc); + /*! + * \brief Add functions from the (unordered) range of elements of type PrimFunc to the current + * module in a deterministic order. + * \param begin The beginning of the range. + * \param end The end of the range. + */ + template + void AddFunctionsOrdered(IterType begin, IterType end) { + this->AddFunctionsOrdered(begin, end, [](auto f) { return f; }); + } /*! * \brief Add mod to be linked with the generated module * \param mod The module to be linked. @@ -377,6 +397,22 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { #endif } +template +void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) { + std::vector funcs; + for (auto it = begin; it != end; ++it) { + funcs.push_back(pfunc(*it)); + } + std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) { + std::string name_a = func_a->GetAttr(tvm::attr::kGlobalSymbol).value(); + std::string name_b = func_b->GetAttr(tvm::attr::kGlobalSymbol).value(); + return name_a < name_b; + }); + for (auto& f : funcs) { + AddFunction(f); + } +} + } // namespace codegen } // namespace tvm #endif // LLVM_VERSION diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 15543eda423f..ebe6d6d67442 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -274,11 +274,11 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; - auto f = Downcast(kv.second); - cg->AddFunction(f); - } + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { + ICHECK(kv.second->template IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + return Downcast(kv.second); + }); const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 8bdf6d1b0422..0e4bca4396f5 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -258,9 +258,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - for (const auto& f : funcs) { - cg->AddFunction(f); - } + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 10cbcd68f362..e5e93ed2c940 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -818,5 +818,32 @@ def do_atomic_add(A): tvm.testing.assert_allclose(a.numpy(), ref, rtol=1e-5) +@tvm.testing.requires_llvm +def test_llvm_order_functions(): + """Check that functions in the LLVM module are ordered alphabetically.""" + + # Note: the order is alphabetical because that's a predictable ordering. Any predictable + # ordering will work fine, but if the ordering changes, this test will need to be updated. + def make_call_extern(caller, callee): + # Create a function: + # float32 caller(float32 v) { return callee(v); } + ib = tvm.tir.ir_builder.create() + v = tvm.te.var("v", dtype="float32") + t = tvm.tir.call_extern("float32", callee, v) + ib.emit(t) + return tvm.tir.PrimFunc([v], ib.get()).with_attr("global_symbol", caller) + + # Create some functions in a random order. + functions = { + "Danny": make_call_extern("Danny", "Dave"), + "Sammy": make_call_extern("Sammy", "Eve"), + "Kirby": make_call_extern("Kirby", "Fred"), + } + mod = tvm.IRModule(functions=functions) + ir_text = tvm.build(mod, None, target="llvm").get_source("ll") + matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) + assert matches == sorted(matches) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))