Skip to content

Commit

Permalink
Create C-runtime-style metadata module for llvm builds.
Browse files Browse the repository at this point in the history
  • Loading branch information
areusch committed Feb 3, 2021
1 parent 44a071a commit 18a54ba
Show file tree
Hide file tree
Showing 14 changed files with 406 additions and 159 deletions.
21 changes: 12 additions & 9 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,16 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi

if not isinstance(target_host, Target):
target_host = Target(target_host)
if (
"system-lib" in target_host.attrs
and target_host.attrs["system-lib"].value == 1
and target_host.kind.name == "c"
):
create_csource_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateCSourceMetadataModule"
)
return create_csource_metadata_module([rt_mod_host], target_host)
if "system-lib" in target_host.attrs and target_host.attrs["system-lib"].value == 1:
if target_host.kind.name == "c":
create_csource_crt_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateCSourceCrtMetadataModule"
)
return create_csource_crt_metadata_module([rt_mod_host], target_host)
elif target_host.kind.name == "llvm":
create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateLLVMCrtMetadataModule"
)
return create_llvm_crt_metadata_module([rt_mod_host], target_host)

return rt_mod_host
7 changes: 6 additions & 1 deletion python/tvm/micro/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _target_from_sources(cls, sources):
target_strs = set()

for obj in sources:
print("read", obj)
if os.path.splitext(obj)[1] not in (".cc", ".c"):
continue

with open(obj) as obj_f:
for line in obj_f:
m = cls.TVM_TARGET_RE.match(line)
Expand Down Expand Up @@ -247,7 +251,8 @@ def library(self, output, sources, options=None):
)

prefix = self._autodetect_toolchain_prefix(target)
outputs = []
outputs = [s for s in sources if os.path.splitext(s)[1] == ".o"]
sources = [s for s in sources if s not in outputs]
for src in sources:
src_base, src_ext = os.path.splitext(os.path.basename(src))

Expand Down
85 changes: 41 additions & 44 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) {
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
} else if (target_c_runtime_) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
registry_functions_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
AddDebugInformation(function_);
}
Expand Down Expand Up @@ -791,47 +785,50 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
return GetContextPtr(gv_tvm_parallel_barrier_);
}

void CodeGenCPU::AddStartupFunction() {
if (registry_functions_.size() != 0) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : registry_functions_) {
symbols.push_back(sym.first);
funcs.emplace_back(llvm::ConstantExpr::getBitCast(
sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo()));
}
llvm::DataLayout layout(module_.get());
llvm::ArrayType* t_tvm_crt_func_ptrs =
llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size());
llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs");
uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo());
void CodeGenCPU::DefineFunctionRegistry(Array<String> func_names) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : func_names) {
symbols.push_back(sym);
llvm::GlobalVariable* sym_func = new llvm::GlobalVariable(
*module_, ftype_tvm_backend_packed_c_func_, true, llvm::GlobalValue::ExternalLinkage,
nullptr, sym.operator std::string());
funcs.emplace_back(sym_func);
}
llvm::DataLayout layout(module_.get());
llvm::ArrayType* t_tvm_crt_func_ptrs =
llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size());
llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs");
uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo());
#if TVM_LLVM_VERSION >= 100
func_registry_ptrs->setAlignment(llvm::Align(align));
func_registry_ptrs->setAlignment(llvm::Align(align));
#else
func_registry_ptrs->setAlignment(align);
func_registry_ptrs->setAlignment(align);
#endif
llvm::GlobalVariable* func_registry = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage,
llvm::ConstantStruct::get(
t_tvm_crt_func_registry_,
{GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)),
func_registry_ptrs}),
"_tvm_crt_func_registry");
llvm::GlobalVariable* module = new llvm::GlobalVariable(
*module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module");

// Now build TVMSystemLibEntryPoint.
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
"TVMSystemLibEntryPoint", module_.get());
llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry_point_entry);
builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_));
} else {
llvm::GlobalVariable* func_registry = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage,
llvm::ConstantStruct::get(
t_tvm_crt_func_registry_,
{GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)), func_registry_ptrs}),
"_tvm_crt_func_registry");
llvm::GlobalVariable* module = new llvm::GlobalVariable(
*module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module");

// Now build TVMSystemLibEntryPoint.
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
"TVMSystemLibEntryPoint", module_.get());
llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry_point_entry);
builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_));
}

void CodeGenCPU::AddStartupFunction() {
if (!target_c_runtime_) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg) override;

/*!
* \brief A CPU-specific function to create the FuncRegistry.
* \param func_names List of functions to be included, in order.
*/
void DefineFunctionRegistry(Array<String> func_names);

protected:
void AddStartupFunction() final;
// meta data
Expand Down
53 changes: 53 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "../../runtime/library_module.h"
#include "../func_registry_generator.h"
#include "codegen_blob.h"
#include "codegen_cpu.h"
#include "codegen_llvm.h"
#include "llvm_common.h"

Expand Down Expand Up @@ -445,6 +446,58 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
return runtime::Module(n);
});

runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target) {
Array<String> func_names;
for (runtime::Module mod : modules) {
auto pf_funcs = mod.GetFunction("get_func_names");
if (pf_funcs != nullptr) {
Array<String> func_names_ = pf_funcs();
for (const auto& fname : func_names_) {
func_names.push_back(fname);
}
}
}

InitializeLLVM();
auto tm = GetLLVMTargetMachine(target);
bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
ICHECK(system_lib && target_c_runtime)
<< "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; "
<< "got target: " << target->str();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime);

cg->DefineFunctionRegistry(func_names);
auto mod = cg->Finish();
mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
llvm::MDString::get(*ctx, LLVMTargetToString(target)));
mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);

if (tm->getTargetTriple().isOSDarwin()) {
mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}

std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
<< "LLVM module verification failed with the following errors: \n"
<< verify_errors.str();

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), ctx);
for (auto m : modules) {
n->Import(m);
}
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.CreateLLVMCrtMetadataModule")
.set_body_typed([](const Array<runtime::Module>& modules, Target target) {
return CreateLLVMCrtMetadataModule(modules, target);
});

} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
39 changes: 39 additions & 0 deletions src/target/llvm/llvm_module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file llvm_module.h
* \brief Declares top-level shared functions related to the LLVM codegen.
*/

#include <tvm/runtime/container.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

#ifdef TVM_LLVM_VERSION

namespace tvm {
namespace codegen {

runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target);

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
108 changes: 108 additions & 0 deletions src/target/metadata_module.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file metadata_module.cc
* \brief Defines functions that build MetadataModules for C++ and C runtimes.
*/

#include "metadata_module.h"

#include "../runtime/meta_data.h"
#include "llvm/llvm_module.h"
#include "source/source_module.h"

namespace tvm {
namespace codegen {

/*!
* \brief Create a metadata module wrapper. The helper is used by different
* codegens, such as graph runtime codegen and the vm compiler.
*
* \param params The metadata for initialization of all modules.
* \param target_module the internal module that is compiled by tvm.
* \param ext_modules The external modules that needs to be imported inside the metadata
* module(s).
* \param target The target that all the modules are compiled for
* \return The created metadata module that manages initialization of metadata.
*/
runtime::Module CreateMetadataModule(
const std::unordered_map<std::string, runtime::NDArray>& params,
tvm::runtime::Module target_module, const Array<runtime::Module>& ext_modules, Target target) {
Array<tvm::runtime::Module> csource_modules;
Array<tvm::runtime::Module> binary_modules;

auto DSOExportable = [](tvm::runtime::Module& mod) {
return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c");
};

// Wrap all submodules in the initialization wrapper.
std::unordered_map<std::string, std::vector<std::string>> sym_metadata;
for (tvm::runtime::Module mod : ext_modules) {
auto pf_sym = mod.GetFunction("get_symbol");
auto pf_var = mod.GetFunction("get_const_vars");
std::vector<std::string> arrays;
if (pf_sym != nullptr && pf_var != nullptr) {
String symbol = pf_sym();
Array<String> variables = pf_var();
for (size_t i = 0; i < variables.size(); i++) {
arrays.push_back(variables[i].operator std::string());
}
ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol;
sym_metadata[symbol] = arrays;
}
// We only need loading of serialized constant data
// if there are constants present and required by the
// runtime module to be initialized by the binary
// metadata module. If not rest of the modules are
// wrapped in c-source metadata module.

// TODO(@manupa-arm) : we should be able to use csource_metadata
// if the variables are empty when all the runtime modules implement get_func_names
if (arrays.empty() && DSOExportable(mod) && target->kind->name == "c") {
csource_modules.push_back(mod);
} else {
binary_modules.push_back(mod);
}
}

if (target.defined() &&
target->GetAttr<String>("runtime").value_or(String("")) == kTvmRuntimeCrt) {
if (target->kind->name == "c") {
csource_modules.push_back(target_module);
target_module = CreateCSourceCrtMetadataModule(csource_modules, target);
} else if (target->kind->name == "llvm") {
binary_modules.push_back(target_module);
target_module = CreateLLVMCrtMetadataModule(binary_modules, target);
}
} else {
if (!binary_modules.empty()) {
runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata);
binary_meta_mod.Import(target_module);
for (const auto& it : binary_modules) {
binary_meta_mod.Import(it);
}
return binary_meta_mod;
}
}
return target_module;
}

} // namespace codegen
} // namespace tvm
Loading

0 comments on commit 18a54ba

Please sign in to comment.