Skip to content

Commit

Permalink
[LLVM/CG] Sort PrimFuncs when creating LLVM module (apache#8958)
Browse files Browse the repository at this point in the history
* [LLVM/CG] Sort PrimFuncs when creating LLVM module

PrimFuncs are stored in a map where the order of iteration is not
deterministic. This can cause a different llvm::Module to be created
each time, which can defeat debugging tools like -opt-bisect-limit.

Add function CodeGenLLVM::AddFunctionsOrdered that takes a range of
PrimFuncs or objects convertible to PrimFuncs, and adds them to the
LLVM module in a deterministic order.

* Empty commit to restart build

* Add testcase
  • Loading branch information
Krzysztof Parzyszek authored and ylc committed Sep 29, 2021
1 parent bba356f commit e7847ca
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 16 deletions.
10 changes: 5 additions & 5 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,11 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {

cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false);

for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
return Downcast<PrimFunc>(kv.second);
});

const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)();
Expand Down
5 changes: 2 additions & 3 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
}

cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
for (const PrimFunc& f : funcs) {
cg->AddFunction(f);
}
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());

if (!linked_params.empty()) {
cg->LinkParameters(linked_params);
}
Expand Down
36 changes: 36 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -92,6 +93,25 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* \return the created module.
*/
virtual std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Add functions from the (unordered) range to the current module in a deterministic order.
* The range consists of objects convertible to PrimFunc.
* \param begin The beginning of the range.
* \param end The end of the range.
* \param pfunc Converter function from the range element type to PrimFunc.
*/
template <typename IterType, typename ConvType>
void AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc);
/*!
* \brief Add functions from the (unordered) range of elements of type PrimFunc to the current
* module in a deterministic order.
* \param begin The beginning of the range.
* \param end The end of the range.
*/
template <typename IterType>
void AddFunctionsOrdered(IterType begin, IterType end) {
this->AddFunctionsOrdered(begin, end, [](auto f) { return f; });
}
/*!
* \brief Add mod to be linked with the generated module
* \param mod The module to be linked.
Expand Down Expand Up @@ -377,6 +397,22 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
#endif
}

template <typename IterType, typename ConvType>
void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
std::vector<PrimFunc> funcs;
for (auto it = begin; it != end; ++it) {
funcs.push_back(pfunc(*it));
}
std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) {
std::string name_a = func_a->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
std::string name_b = func_b->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
return name_a < name_b;
});
for (auto& f : funcs) {
AddFunction(f);
}
}

} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
Expand Down
10 changes: 5 additions & 5 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) {

cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false);

for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
return Downcast<PrimFunc>(kv.second);
});

const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
if (flibdevice_path != nullptr) {
Expand Down
4 changes: 1 addition & 3 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);

for (const auto& f : funcs) {
cg->AddFunction(f);
}
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());

if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,5 +818,32 @@ def do_atomic_add(A):
tvm.testing.assert_allclose(a.numpy(), ref, rtol=1e-5)


@tvm.testing.requires_llvm
def test_llvm_order_functions():
"""Check that functions in the LLVM module are ordered alphabetically."""

# Note: the order is alphabetical because that's a predictable ordering. Any predictable
# ordering will work fine, but if the ordering changes, this test will need to be updated.
def make_call_extern(caller, callee):
# Create a function:
# float32 caller(float32 v) { return callee(v); }
ib = tvm.tir.ir_builder.create()
v = tvm.te.var("v", dtype="float32")
t = tvm.tir.call_extern("float32", callee, v)
ib.emit(t)
return tvm.tir.PrimFunc([v], ib.get()).with_attr("global_symbol", caller)

# Create some functions in a random order.
functions = {
"Danny": make_call_extern("Danny", "Dave"),
"Sammy": make_call_extern("Sammy", "Eve"),
"Kirby": make_call_extern("Kirby", "Fred"),
}
mod = tvm.IRModule(functions=functions)
ir_text = tvm.build(mod, None, target="llvm").get_source("ll")
matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE)
assert matches == sorted(matches)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit e7847ca

Please sign in to comment.