Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove support for run-time linked-params from codegen #11144

Merged
merged 1 commit into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,6 @@
#include <vector>

namespace tvm {
/*!
* \brief Describes one parameter that should be linked into the generated module.
*
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
* use the information contained in this node to include the parameter data in the generated
* module.
*/
class LinkedParamNode : public Object {
public:
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
int64_t id;

/*! \brief Parameter data which should get linked into the final module. */
::tvm::runtime::NDArray param;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("id", &id);
v->Visit("param", &param);
}

static constexpr const char* _type_key = "tir.LinkedParam";
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
};

/*!
* \brief Managed reference to LinkedParamNode.
*/
class LinkedParam : public ObjectRef {
public:
TVM_DLL LinkedParam(int64_t id, tvm::runtime::NDArray param);

TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

class IRModule;

Expand Down
10 changes: 0 additions & 10 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,16 +333,6 @@ constexpr const char* kNoAlias = "tir.noalias";
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";

/*!
* \brief Parameters used in the module that should be linked by the codegen.
*
* Type: Map<String, LinkableParam>
*
* \note This should be present only on a function named
* tvm::target::packed_func::kLookupLinkedParam.
*/
constexpr const char* kLinkedParams = "tir.linked_params";

/*!
* \brief Mark the function as the global function called from the host.
*
Expand Down
7 changes: 0 additions & 7 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,6 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
return tvm::parser::ParseModule(source_path, text);
}

LinkedParam::LinkedParam(int64_t id, tvm::runtime::NDArray param) {
auto n = make_object<LinkedParamNode>();
n->id = id;
n->param = param;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(IRModuleNode);

TVM_REGISTER_GLOBAL("ir.IRModule")
Expand Down
19 changes: 0 additions & 19 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,23 +331,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {

std::vector<PrimFunc> funcs;
std::string entry_func;
Map<String, LinkedParam> linked_params;
bool could_have_linked_params = mod->ShouldLinkParameters();

for (auto kv : mod->functions) {
if (could_have_linked_params &&
kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
// If `f` is the linked-params function, extract the parameters from the
// attribute dictionary, and skip the codegen.
auto attrs_dict = Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
<< "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";

CHECK(linked_params.empty()) << "Multiple linked-param functions";
linked_params =
Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
continue;
}
if (!kv.second->IsInstance<PrimFuncNode>()) {
// (@jroesch): we relax constraints here, Relay functions will just be ignored.
DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey();
Expand All @@ -368,10 +353,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
cg->AddMainFunction(entry_func);
}

if (!linked_params.empty()) {
cg->LinkParameters(linked_params);
}

// Uncomment to get the LLVM module right out of codegen, before optimizations.
// std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n";
std::unique_ptr<llvm::Module> module = cg->Finish();
Expand Down
74 changes: 1 addition & 73 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,76 +192,6 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
}
}

llvm::GlobalVariable* CodeGenLLVM::GetLinkedParamSymbol(const std::string& param_name,
llvm::ConstantArray* array) {
std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + param_name;
llvm::GlobalVariable* var = module_->getGlobalVariable(symbol_name, true /* AllowInternal */);
if (var == nullptr) {
CHECK(array != nullptr) << "Expect param symbol " << symbol_name
<< " to either be defined or for the array to be supplied";
var = new llvm::GlobalVariable(*module_, static_cast<llvm::Type*>(array->getType()), true,
llvm::GlobalValue::InternalLinkage, array, symbol_name);
}
return var;
}

void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
// It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
// but they are at a different layer in the compiler...
llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace());

// args, tcodes, num_args, ret_value, ret_tcode, resource_handle
std::vector<llvm::Type*> param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_};
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);

llvm::Function* function =
llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
function->setCallingConv(llvm::CallingConv::C);
function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);

llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
builder_->SetInsertPoint(entry);

llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace());
llvm::Value* sid =
builder_->CreateLoad(t_int64_, builder_->CreateBitCast(GetArg(function, 0), t_int64_p));

auto ret_tcode = builder_->CreateBitCast(GetArg(function, 4), t_int_p);
auto ret_value = builder_->CreateBitCast(GetArg(function, 3),
t_void_p_->getPointerTo(GetGlobalAddressSpace()));

llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);

builder_->SetInsertPoint(default_block);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode);
builder_->CreateRet(ConstInt32(kTvmErrorNoError));

// Add data to the global section.
for (auto kv : params) {
auto array = NDArrayToLLVMArray(ctx_, kv.second->param);
llvm::GlobalVariable* param_symbol = GetLinkedParamSymbol(kv.first, array);
auto dtype = tvm::runtime::DataType(kv.second->param->dtype);
size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment);
#if TVM_LLVM_VERSION >= 100
param_symbol->setAlignment(llvm::Align(align));
#else
param_symbol->setAlignment(align);
#endif
param_symbol->setInitializer(array);

llvm::BasicBlock* case_block =
llvm::BasicBlock::Create(*ctx_, "case_" + param_symbol->getName(), function);
switch_inst->addCase(
llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
builder_->SetInsertPoint(case_block);
builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);
builder_->CreateRet(ConstInt32(0));
}
}

std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
for (size_t i = 0; i < link_modules_.size(); ++i) {
Expand Down Expand Up @@ -1419,9 +1349,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
if (op->op.same_as(builtin_lookup_param_)) {
return GetLinkedParamSymbol(Downcast<StringImm>(op->args[0])->value, nullptr);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
// call extern intrinsic
ICHECK_GE(op->args.size(), 1U);
auto global_symbol = Downcast<StringImm>(op->args[0]);
Expand Down
20 changes: 0 additions & 20 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* \param mod The module to be linked.
*/
void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
/*!
* \brief Link parameters into the module so they don't need to be supplied at runtime.
* Parameters can be linked into the module so that the generated code is easier to use, or so
* that RAM space doesn't need to be allocated for them. This function adds the given parameters
* to the generated LLVM module.
* \param storage_id_offset Offset added to the index of each entry in params_by_sid to form the
* storage_id of that parameter. Storage ids for parameters are expected to be contiguous.
* \param params_by_sid Array of NDArray. Each entry is a parameter. The index of the array (added
* to sid_offset) is the storage_id of the param.
* \param param_names Array containing the name for each param in params_by_sid.
*/
void LinkParameters(const Map<String, LinkedParam> params);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
Expand Down Expand Up @@ -349,14 +337,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Lookup or create a GlobalVariable whose content is the data field of a DLTensor for a
* given linked_param() CallNode.
* \param param_name Parameter name (e.g. unmangled, from lookup_param node).
* \return the GlobalVariable indicated in the brief.
*/
llvm::GlobalVariable* GetLinkedParamSymbol(const ::std::string& param_name,
llvm::ConstantArray* array);
/*!
* \brief Get the number of elements in the given vector value.
* \param vec The value, must be of a vector type.
Expand Down
19 changes: 1 addition & 18 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,26 +221,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {

std::vector<PrimFunc> funcs;
std::string entry_func;
Map<String, LinkedParam> linked_params;
bool found_linked_params = false;
bool could_have_linked_params = mod->ShouldLinkParameters();
relay::Runtime runtime =
mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
bool target_c_runtime = runtime->name == "crt";

for (auto kv : mod->functions) {
if (could_have_linked_params &&
kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) {
Map<String, ObjectRef> attrs_dict =
Downcast<Map<String, ObjectRef>>(kv.second->attrs->dict);
CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end())
<< "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!";
linked_params =
Downcast<Map<String, LinkedParam>>(attrs_dict[::tvm::tir::attr::kLinkedParams]);
found_linked_params = true;
continue;
}
if (!kv.second->IsInstance<PrimFuncNode>()) {
// (@jroesch): we relax constraints here, Relay functions will just be ignored.
DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got "
Expand All @@ -257,7 +243,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
funcs.push_back(f);
}
// TODO(@jroesch): follow up on this condition.
// ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
// ICHECK(funcs.size() > 0);
// TODO(tqchen): remove the entry function behavior as it does not
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);
Expand Down Expand Up @@ -308,9 +294,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {

cg->SetFastMathFlag(fmf);

if (found_linked_params) {
cg->LinkParameters(linked_params);
}
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, target->str(), devices);
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
Map<String, LinkedParam> linked_params;
PrimFunc aot_executor_fn;

std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
Expand Down
61 changes: 0 additions & 61 deletions tests/python/unittest/test_target_codegen_hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,66 +115,5 @@ def test_llvm_options():
assert re.search("-hexagon-noopt", str(target))


@tvm.testing.requires_hexagon
def test_linked_params_codegen():
# A simple model (a single conv2d) to trigger parameter separation:
mod_lines = [
'#[version = "0.0.5"]',
"def @main(%input: Tensor[(1, 16, 16, 3), uint8], %weights: Tensor[(3, 3, 3, 3), uint8])"
" -> Tensor[(1, 14, 14, 3), uint8] {",
' nn.conv2d(%input, %weights, data_layout="NHWC", kernel_layout="HWIO", '
'kernel_size=[3, 3], out_dtype="uint8")',
"}",
]
mod = tvm.parser.fromtext("\n".join(mod_lines))
# Make the params be 81 x 'T':
params = {"weights": np.full([3, 3, 3, 3], fill_value=ord("T"), dtype=np.uint8)}

target = tvm.target.hexagon("v68", link_params=True)

with tvm.transform.PassContext(opt_level=3):
lib = tvm.relay.build(mod, target=target, params=params)
llvm_ir = lib.get_lib().get_source("ll")

# The definition of the parameter:
p0_def_re = r"@__tvm_param__p0 = internal constant \[81 x i8\] c\"T{81}\", align 128"
assert re.search(p0_def_re, llvm_ir)

# The body of the _lookup_linked_param function:
linked_param_re = r"(define.*@_lookup_linked_param\(.*\).* {[^}]*})"
linked_param_body = re.search(linked_param_re, llvm_ir, flags=re.MULTILINE)
assert linked_param_body and linked_param_body.groups()

# Reference to the parameter:
p0_use_re = r"\[81 x i8\]\* @__tvm_param__p0"
assert re.search(p0_use_re, linked_param_body.groups()[0])

"""
A snippet of actual LLVM IR containing the definition of the linked
parameter, and the the body of the _lookup_linked_param function.


@__tvm_param__p0 = internal constant [81 x i8] c"TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT", align 128

define dllexport i32 @_lookup_linked_param(i8* nocapture readonly %0, i32* nocapture readnone %1, i32 %2, i8* nocapture %3, i32* nocapture %4, i8* nocapture readnone %5) local_unnamed_addr #2 {
entry:
%6 = bitcast i8* %0 to i64*
%7 = load i64, i64* %6, align 8
%cond = icmp eq i64 %7, 1
br i1 %cond, label %case___tvm_param__p0, label %common.ret

common.ret: ; preds = %entry, %case___tvm_param__p0
%storemerge = phi i32 [ 3, %case___tvm_param__p0 ], [ 4, %entry ]
store i32 %storemerge, i32* %4, align 4
ret i32 0

case___tvm_param__p0: ; preds = %entry
%8 = bitcast i8* %3 to i8**
store i8* getelementptr inbounds ([81 x i8], [81 x i8]* @__tvm_param__p0, i32 0, i32 0), i8** %8, align 4
br label %common.ret
}
"""


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))