Skip to content

Commit

Permalink
[CRT] Create C-runtime-style metadata module for llvm builds (apache#…
Browse files Browse the repository at this point in the history
…7398)

* Create C-runtime-style metadata module for llvm builds.

* maybe address manupa's comment

* lint

* actually address manupa comments

* comment and rename

* git-clang-format

* pylint

* cpp warning

* try to fix apps/bundle_deploy

* black format

* build correct file

* Use save() for C++-runtime targeted artifacts.

* fix build_module LLVM metadata module conditions

* fix test comment

* black format

* further restrict CRT MetadataModule creation

* Fix test_link_params

* black format and address zhiics comments

* fix test_link_params, i think?
  • Loading branch information
areusch authored and trevor-m committed Mar 2, 2021
1 parent dbf757d commit 108bf66
Show file tree
Hide file tree
Showing 16 changed files with 498 additions and 173 deletions.
19 changes: 13 additions & 6 deletions apps/bundle_deploy/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ $(endif)

CRT_SRCS = $(shell find $(CRT_ROOT))

MODEL_OBJ = $(build_dir)/model_c/devc.o $(build_dir)/model_c/lib0.o $(build_dir)/model_c/lib1.o
TEST_MODEL_OBJ = $(build_dir)/test_model_c/devc.o $(build_dir)/test_model_c/lib0.o $(build_dir)/test_model_c/lib1.o

demo_dynamic: $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/graph_c.json $(build_dir)/params_cpp.bin $(build_dir)/params_c.bin $(build_dir)/cat.bin
$(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/params_cpp.bin $(build_dir)/cat.bin
$(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle_c.so $(build_dir)/graph_c.json $(build_dir)/params_c.bin $(build_dir)/cat.bin
Expand Down Expand Up @@ -93,11 +96,11 @@ $(build_dir)/test_dynamic: test.cc ${build_dir}/test_graph_c.json ${build_dir}/t
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ $(PKG_CXXFLAGS) -o $@ test.cc $(BACKTRACE_OBJS) $(BACKTRACE_LDFLAGS)

$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a ${build_dir}/graph_c.json.c ${build_dir}/params_c.bin.c $(BACKTRACE_OBJS)
$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o $(MODEL_OBJ) ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a ${build_dir}/graph_c.json.c ${build_dir}/params_c.bin.c $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS)

$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o $(TEST_MODEL_OBJ) ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_LDFLAGS)

Expand All @@ -119,27 +122,31 @@ $(build_dir)/params_c.bin.c: $(build_dir)/params_c.bin
$(build_dir)/params_cpp.bin.c: $(build_dir)/params_cpp.bin
$(QUIET)xxd -i $^ > $@

$(build_dir)/model_c.o $(build_dir)/graph_c.json $(build_dir)/model_cpp.o $(build_dir)/graph_cpp.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py
$(MODEL_OBJ) $(build_dir)/graph_c.json $(build_dir)/model_cpp.o $(build_dir)/graph_cpp.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py
$(QUIET)python3 $< -o $(build_dir)
$(QUIET)mkdir -p build/model_c
$(QUIET)tar -C build/model_c -xvf build/model_c.tar

$(build_dir)/test_model_c.o $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_model_cpp.o $(build_dir)/test_graph_cpp.json $(build_dir)/test_params_cpp.bin $(build_dir)/test_data_cpp.bin $(build_dir)/test_output_cpp.bin: build_model.py
$(TEST_MODEL_OBJ) $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_model_cpp.o $(build_dir)/test_graph_cpp.json $(build_dir)/test_params_cpp.bin $(build_dir)/test_data_cpp.bin $(build_dir)/test_output_cpp.bin: build_model.py
$(QUIET)python3 $< -o $(build_dir) --test
$(QUIET)mkdir -p build/test_model_c
$(QUIET)tar -C build/test_model_c -xvf build/test_model_c.tar

# Build our bundle against the serialized bundle.c API, the runtime.cc API, and
# the serialized graph.json and params.bin
$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model_cpp.o
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)

$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(build_dir)/bundle_c.so: bundle.c $(MODEL_OBJ) ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS)

$(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model_cpp.o
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)

$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(build_dir)/test_bundle_c.so: bundle.c $(TEST_MODEL_OBJ) ${build_dir}/crt/libmemory.a ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS)

Expand Down
25 changes: 23 additions & 2 deletions apps/bundle_deploy/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm import te
import logging
import json
from tvm.contrib import cc as _cc

RUNTIMES = {
"c": "{name}_c.{ext}",
Expand Down Expand Up @@ -51,7 +52,17 @@ def build_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)
lib.save(os.path.join(build_dir, file_format_str.format(name="model", ext="o")))
ext = "tar" if runtime_name == "c" else "o"
lib_file_name = os.path.join(build_dir, file_format_str.format(name="model", ext=ext))
if runtime_name == "c":
lib.export_library(lib_file_name)
else:
# NOTE: at present, export_libarary will always create _another_ shared object, and you
# can't stably combine two shared objects together (in this case, init_array is not
# populated correctly when you do that). So for now, must continue to use save() with the
# C++ library.
# TODO(areusch): Obliterate runtime.cc and replace with libtvm_runtime.so.
lib.save(lib_file_name)
with open(
os.path.join(build_dir, file_format_str.format(name="graph", ext="json")), "w"
) as f_graph_json:
Expand Down Expand Up @@ -84,7 +95,17 @@ def build_test_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)
lib.save(os.path.join(build_dir, file_format_str.format(name="test_model", ext="o")))
ext = "tar" if runtime_name == "c" else "o"
lib_file_name = os.path.join(build_dir, file_format_str.format(name="test_model", ext=ext))
if runtime_name == "c":
lib.export_library(lib_file_name)
else:
# NOTE: at present, export_libarary will always create _another_ shared object, and you
# can't stably combine two shared objects together (in this case, init_array is not
# populated correctly when you do that). So for now, must continue to use save() with the
# C++ library.
# TODO(areusch): Obliterate runtime.cc and replace with libtvm_runtime.so.
lib.save(lib_file_name)
with open(
os.path.join(build_dir, file_format_str.format(name="test_graph", ext="json")), "w"
) as f_graph_json:
Expand Down
21 changes: 14 additions & 7 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,19 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
if not isinstance(target_host, Target):
target_host = Target(target_host)
if (
"system-lib" in target_host.attrs
and target_host.attrs["system-lib"].value == 1
and target_host.kind.name == "c"
target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c"
and target_host.attrs.get("system-lib", 0).value == 1
):
create_csource_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateCSourceMetadataModule"
)
return create_csource_metadata_module([rt_mod_host], target_host)
if target_host.kind.name == "c":
create_csource_crt_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateCSourceCrtMetadataModule"
)
return create_csource_crt_metadata_module([rt_mod_host], target_host)

if target_host.kind.name == "llvm":
create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateLLVMCrtMetadataModule"
)
return create_llvm_crt_metadata_module([rt_mod_host], target_host)

return rt_mod_host
6 changes: 5 additions & 1 deletion python/tvm/micro/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def _target_from_sources(cls, sources):
target_strs = set()

for obj in sources:
if os.path.splitext(obj)[1] not in (".cc", ".c"):
continue

with open(obj) as obj_f:
for line in obj_f:
m = cls.TVM_TARGET_RE.match(line)
Expand Down Expand Up @@ -246,7 +249,8 @@ def library(self, output, sources, options=None):
)

prefix = self._autodetect_toolchain_prefix(target)
outputs = []
outputs = [s for s in sources if os.path.splitext(s)[1] == ".o"]
sources = [s for s in sources if s not in outputs]
for src in sources:
src_base, src_ext = os.path.splitext(os.path.basename(src))

Expand Down
85 changes: 41 additions & 44 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) {
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
} else if (target_c_runtime_) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
registry_functions_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
AddDebugInformation(function_);
}
Expand Down Expand Up @@ -791,47 +785,50 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
return GetContextPtr(gv_tvm_parallel_barrier_);
}

void CodeGenCPU::AddStartupFunction() {
if (registry_functions_.size() != 0) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : registry_functions_) {
symbols.push_back(sym.first);
funcs.emplace_back(llvm::ConstantExpr::getBitCast(
sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo()));
}
llvm::DataLayout layout(module_.get());
llvm::ArrayType* t_tvm_crt_func_ptrs =
llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size());
llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs");
uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo());
void CodeGenCPU::DefineFunctionRegistry(Array<String> func_names) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : func_names) {
symbols.push_back(sym);
llvm::GlobalVariable* sym_func = new llvm::GlobalVariable(
*module_, ftype_tvm_backend_packed_c_func_, true, llvm::GlobalValue::ExternalLinkage,
nullptr, sym.operator std::string());
funcs.emplace_back(sym_func);
}
llvm::DataLayout layout(module_.get());
llvm::ArrayType* t_tvm_crt_func_ptrs =
llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size());
llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs");
uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo());
#if TVM_LLVM_VERSION >= 100
func_registry_ptrs->setAlignment(llvm::Align(align));
func_registry_ptrs->setAlignment(llvm::Align(align));
#else
func_registry_ptrs->setAlignment(align);
func_registry_ptrs->setAlignment(align);
#endif
llvm::GlobalVariable* func_registry = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage,
llvm::ConstantStruct::get(
t_tvm_crt_func_registry_,
{GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)),
func_registry_ptrs}),
"_tvm_crt_func_registry");
llvm::GlobalVariable* module = new llvm::GlobalVariable(
*module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module");

// Now build TVMSystemLibEntryPoint.
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
"TVMSystemLibEntryPoint", module_.get());
llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry_point_entry);
builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_));
} else {
llvm::GlobalVariable* func_registry = new llvm::GlobalVariable(
*module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage,
llvm::ConstantStruct::get(
t_tvm_crt_func_registry_,
{GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)), func_registry_ptrs}),
"_tvm_crt_func_registry");
llvm::GlobalVariable* module = new llvm::GlobalVariable(
*module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage,
llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module");

// Now build TVMSystemLibEntryPoint.
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
"TVMSystemLibEntryPoint", module_.get());
llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry_point_entry);
builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_));
}

void CodeGenCPU::AddStartupFunction() {
if (!target_c_runtime_) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg) override;

/*!
* \brief A CPU-specific function to create the FuncRegistry.
* \param func_names List of functions to be included, in order.
*/
void DefineFunctionRegistry(Array<String> func_names);

protected:
void AddStartupFunction() final;
// meta data
Expand Down
53 changes: 53 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "../../runtime/library_module.h"
#include "../func_registry_generator.h"
#include "codegen_blob.h"
#include "codegen_cpu.h"
#include "codegen_llvm.h"
#include "llvm_common.h"

Expand Down Expand Up @@ -445,6 +446,58 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
return runtime::Module(n);
});

runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target) {
Array<String> func_names;
for (runtime::Module mod : modules) {
auto pf_funcs = mod.GetFunction("get_func_names");
if (pf_funcs != nullptr) {
Array<String> func_names_ = pf_funcs();
for (const auto& fname : func_names_) {
func_names.push_back(fname);
}
}
}

InitializeLLVM();
auto tm = GetLLVMTargetMachine(target);
bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
ICHECK(system_lib && target_c_runtime)
<< "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; "
<< "got target: " << target->str();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime);

cg->DefineFunctionRegistry(func_names);
auto mod = cg->Finish();
mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
llvm::MDString::get(*ctx, LLVMTargetToString(target)));
mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);

if (tm->getTargetTriple().isOSDarwin()) {
mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}

std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
<< "LLVM module verification failed with the following errors: \n"
<< verify_errors.str();

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), ctx);
for (auto m : modules) {
n->Import(m);
}
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.CreateLLVMCrtMetadataModule")
.set_body_typed([](const Array<runtime::Module>& modules, Target target) {
return CreateLLVMCrtMetadataModule(modules, target);
});

} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
44 changes: 44 additions & 0 deletions src/target/llvm/llvm_module.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file llvm_module.h
* \brief Declares top-level shared functions related to the LLVM codegen.
*/

#ifndef TVM_TARGET_LLVM_LLVM_MODULE_H_
#define TVM_TARGET_LLVM_LLVM_MODULE_H_

#include <tvm/runtime/container.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

#ifdef TVM_LLVM_VERSION

namespace tvm {
namespace codegen {

runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& modules, Target target);

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION

#endif // TVM_TARGET_LLVM_LLVM_MODULE_H_
Loading

0 comments on commit 108bf66

Please sign in to comment.