Skip to content

Commit

Permalink
Remove support for run-time linked-params from codegen (apache#11144)
Browse files Browse the repository at this point in the history
Linking parameters via a runtime lookup function no longer happens after
commit b5f1dab (PR#8509): "Tir constants integration into compilation
pipeline". Now, in cases where the runtime lookup would have happened in
the past, the parameters are embedded into TIR, removing the need for a
runtime lookup.

There is still plenty of code around that implemented the original runtime
lookup. This patch removes the unnecessary leftovers from TVM's codegen.
  • Loading branch information
Krzysztof Parzyszek authored and Sergey Shtin committed May 17, 2022
1 parent 4d62b2f commit 84e1aa8
Show file tree
Hide file tree
Showing 9 changed files with 2 additions and 244 deletions.
35 changes: 0 additions & 35 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,6 @@
#include <vector>

namespace tvm {
/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

class IRModule;

Expand Down
10 changes: 0 additions & 10 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,16 +333,6 @@ constexpr const char* kNoAlias = "tir.noalias";
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";

/*!
* \brief Parameters used in the module that should be linked by the codegen.
*
* Type: Map<String, LinkableParam>
*
* \note This should be present only on a function named
* tvm::target::packed_func::kLookupLinkedParam.
*/
constexpr const char* kLinkedParams = "tir.linked_params";

/*!
* \brief Mark the function as the global function called from the host.
*
Expand Down
7 changes: 0 additions & 7 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,6 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
return tvm::parser::ParseModule(source_path, text);
}

LinkedParam::LinkedParam(int64_t id, tvm::runtime::NDArray param) {
auto n = make_object<LinkedParamNode>();
n->id = id;
n->param = param;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(IRModuleNode);

TVM_REGISTER_GLOBAL("ir.IRModule")
Expand Down
19 changes: 0 additions & 19 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,23 +331,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {

std::vector<PrimFunc> funcs;
std::string entry_func;
Map<String, LinkedParam> linked_params;
bool could_have_linked_params = mod->ShouldLinkParameters();

for (auto kv : mod->functions) {
if (could_have_linked_params &&
kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
// If `f` is the linked-params function, extract the parameters from the
// attribute dictionary, and skip the codegen.
auto attrs_dict = Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
<< "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";

CHECK(linked_params.empty()) << "Multiple linked-param functions";
linked_params =
Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
continue;
}
if (!kv.second->IsInstance<PrimFuncNode>()) {
// (@jroesch): we relax constraints here, Relay functions will just be ignored.
DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey();
Expand All @@ -368,10 +353,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
cg->AddMainFunction(entry_func);
}

if (!linked_params.empty()) {
cg->LinkParameters(linked_params);
}

// Uncomment to get the LLVM module right out of codegen, before optimizations.
// std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n";
std::unique_ptr<llvm::Module> module = cg->Finish();
Expand Down
74 changes: 1 addition & 73 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,76 +192,6 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
}
}

llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name,
llvm::ConstantArray* array) {
std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name;
llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name, true /* AllowInternal */);
if (var == nullptr) {
CHECK(array != nullptr) << "Expect param symbol " << symbol_name
<< " to either be defined or for the array to be supplied";
var = new llvm::GlobalVariable(*module_, static_cast<llvm::Type*>(array->getType()), true,
llvm::GlobalValue::InternalLinkage, array, symbol_name);
}
return var;
}

void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
// It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
// but they are at a different layer in the compiler...
llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace());

// args, tcodes, num_args, ret_value, ret_tcode, resource_handle
std::vector<llvm::Type*> param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_};
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);

llvm::Function* function =
llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
function->setCallingConv(llvm::CallingConv::C);
function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
builder_->SetInsertPoint(entry);

llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace());
llvm::Value* sid =
builder_->CreateLoad(t_int64_, builder_->CreateBitCast(GetArg(function, 0), t_int64_p));

auto ret_tcode = builder_->CreateBitCast(GetArg(function, 4), t_int_p);
auto ret_value = builder_->CreateBitCast(GetArg(function, 3),
t_void_p_->getPointerTo(GetGlobalAddressSpace()));

llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);

builder_->SetInsertPoint(default_block);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode);
builder_->CreateRet(ConstInt32(kTvmErrorNoError));

// Add data to the global section.
for (auto kv : params) {
auto array = NDArrayToLLVMArray(ctx_, kv.second->param);
llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array);
auto dtype = tvm::runtime::DataType(kv.second->param->dtype);
size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment);
#if TVM_LLVM_VERSION >= 100
param_symbol->setAlignment(llvm::Align(align));
#else
param_symbol->setAlignment(align);
#endif
param_symbol->setInitializer(array);

llvm::BasicBlock* case_block =
llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function);
switch_inst->addCase(
llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
builder_->SetInsertPoint(case_block);
builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);
builder_->CreateRet(ConstInt32(0));
}
}

std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
for (size_t i = 0; i < link_modules_.size(); ++i) {
Expand Down Expand Up @@ -1419,9 +1349,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
if (op->op.same_as(builtin_lookup_param_)) {
return GetLinkedParamSymbol(Downcast<StringImm>(op->args[0])->value, nullptr);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
// call extern intrinsic
ICHECK_GE(op->args.size(), 1U);
auto global_symbol = Downcast<StringImm>(op->args[0]);
Expand Down
20 changes: 0 additions & 20 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* \param mod The module to be linked.
*/
void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
/*!
* \brief Link parameters into the module so they don't need to be supplied at runtime.
* Parameters can be linked into the module so that the generated code is easier to use, or so
* that RAM space doesn't need to be allocated for them. This function adds the given parameters
* to the generated LLVM module.
* \param storage_id_offset Offset added to the index of each entry in params_by_sid to form the
* storage_id of that parameter. Storage ids for parameters are expected to be contiguous.
* \param params_by_sid Array of NDArray. Each entry is a parameter. The index of the array (added
* to sid_offset) is the storage_id of the param.
* \param param_names Array containing the name for each param in params_by_sid.
*/
void LinkParameters(const Map<String, LinkedParam> params);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
Expand Down Expand Up @@ -349,14 +337,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a
* given linked_param() CallNode.
* \param param_name Parameter name (e.g. unmangled, from lookup_param node).
* \return the GlobalVariable indicated in the brief.
*/
llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name,
llvm::ConstantArray* array);
/*!
* \brief Get the number of elements in the given vector value.
* \param vec The value, must be of a vector type.
Expand Down
19 changes: 1 addition & 18 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,26 +221,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {

std::vector<PrimFunc> funcs;
std::string entry_func;
Map<String, LinkedParam> linked_params;
bool found_linked_params = false;
bool could_have_linked_params = mod->ShouldLinkParameters();
relay::Runtime runtime =
mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = runtime->name == "crt";

for (auto kv : mod->functions) {
if (could_have_linked_params &&
kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
Map<String, ObjectRef> attrs_dict =
Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
<< "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";
linked_params =
Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
found_linked_params = true;
continue;
}
if (!kv.second->IsInstance<PrimFuncNode>()) {
// (@jroesch): we relax constraints here, Relay functions will just be ignored.
DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got "
Expand All @@ -257,7 +243,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
funcs.push_back(f);
}
// TODO(@jroesch): follow up on this condition.
// ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
// ICHECK(funcs.size() > 0);
// TODO(tqchen): remove the entry function behavior as it does not
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);
Expand Down Expand Up @@ -308,9 +294,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {

cg->SetFastMathFlag(fmf);

if (found_linked_params) {
cg->LinkParameters(linked_params);
}
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, target->str(), devices);
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
Map<String, LinkedParam> linked_params;
PrimFunc aot_executor_fn;

std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
Expand Down
61 changes: 0 additions & 61 deletions tests/python/unittest/test_target_codegen_hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,66 +115,5 @@ def test_llvm_options():
assert re.search("-hexagon-noopt", str(target))


@tvm.testing.requires_hexagon
def test_linked_params_codegen():
# A simple model (a single conv2d) to trigger parameter separation:
mod_lines = [
'#[version = "0.0.5"]',
"def @main(%input: Tensor[(1, 16, 16, 3), uint8], %weights: Tensor[(3, 3, 3, 3), uint8])"
" -> Tensor[(1, 14, 14, 3), uint8] {",
' nn.conv2d(%input, %weights, data_layout="NHWC", kernel_layout="HWIO", '
'kernel_size=[3, 3], out_dtype="uint8")',
"}",
]
mod = tvm.parser.fromtext("\n".join(mod_lines))
# Make the params be 81 x 'T':
params = {"weights": np.full([3, 3, 3, 3], fill_value=ord("T"), dtype=np.uint8)}

target = tvm.target.hexagon("v68", link_params=True)

with tvm.transform.PassContext(opt_level=3):
lib = tvm.relay.build(mod, target=target, params=params)
llvm_ir = lib.get_lib().get_source("ll")

# The definition of the parameter:
p0_def_re = r"@__tvm_param__p0 = internal constant \[81 x i8\] c\"T{81}\", align 128"
assert re.search(p0_def_re, llvm_ir)

# The body of the _lookup_linked_param function:
linked_param_re = r"(define.*@_lookup_linked_param\(.*\).* {[^}]*})"
linked_param_body = re.search(linked_param_re, llvm_ir, flags=re.MULTILINE)
assert linked_param_body and linked_param_body.groups()

# Reference to the parameter:
p0_use_re = r"\[81 x i8\]\* @__tvm_param__p0"
assert re.search(p0_use_re, linked_param_body.groups()[0])

"""
A snippet of actual LLVM IR containing the definition of the linked
parameter, and the the body of the _lookup_linked_param function.
@__tvm_param__p0 = internal constant [81 x i8] c"TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT", align 128
define dllexport i32 @_lookup_linked_param(i8* nocapture readonly %0, i32* nocapture readnone %1, i32 %2, i8* nocapture %3, i32* nocapture %4, i8* nocapture readnone %5) local_unnamed_addr #2 {
entry:
%6 = bitcast i8* %0 to i64*
%7 = load i64, i64* %6, align 8
%cond = icmp eq i64 %7, 1
br i1 %cond, label %case___tvm_param__p0, label %common.ret
common.ret: ; preds = %entry, %case___tvm_param__p0
%storemerge = phi i32 [ 3, %case___tvm_param__p0 ], [ 4, %entry ]
store i32 %storemerge, i32* %4, align 4
ret i32 0
case___tvm_param__p0: ; preds = %entry
%8 = bitcast i8* %3 to i8**
store i8* getelementptr inbounds ([81 x i8], [81 x i8]* @__tvm_param__p0, i32 0, i32 0), i8** %8, align 4
br label %common.ret
}
"""


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 84e1aa8

Please sign in to comment.