diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index d4b538a4bef0..af2f5d857293 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -483,6 +483,8 @@ def hexagon(cpu_ver="v66", **kwargs): error if invalid. Does not affect codegen. llvm_options : str or list of str (default: None) User defined compiler arguments. + link_params : bool (default: False) + Whether to link graph parameters into the LLVM module. """ # Some of the target parameters correspond to target kind attributes @@ -507,6 +509,7 @@ def hexagon(cpu_ver="v66", **kwargs): "hvx": 128, "sim_options": None, "llvm_options": None, + "link_params": False, } config.update(kwargs) @@ -615,12 +618,27 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument args = [s.replace("=", "@") for s in llvm_options.split()] return "--llvm-options=" + ",".join(args) + # TVM target attributes string + def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument + """ Create TVM target features string. """ + + features = { + "link_params": "link-params", + } + opts = "" + for k in config: + if k in features: + opts += " --" + features[k] + "=" + str(config[k]) + return opts + # Sim args os.environ["HEXAGON_SIM_ARGS"] = create_sim_options(cpu_ver, config) target_str = create_llvm_target(cpu_ver, config) llvm_str = create_llvm_options(cpu_ver, config) - args_list = target_str.split() + llvm_str.split() + tvm_str = create_tvm_options(cpu_ver, config) + + args_list = target_str.split() + llvm_str.split() + tvm_str.split() return Target(" ".join(["hexagon"] + args_list)) diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 26356a547990..e9eacc27fc72 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -706,12 +706,38 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { std::unique_ptr tm = GetLLVMTargetMachine(target); std::unique_ptr ctx(new llvm::LLVMContext()); std::unique_ptr cg(new CodeGenHexagon()); - cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); + + std::vector funcs; + Map linked_params; + bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); + for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; + if (could_have_linked_params && + kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { + // If `f` is the linked-params function, extract the parameters from the + // attribute dictionary, and skip the codegen. + auto attrs_dict = Downcast>(kv.second->attrs->dict); + CHECK(attrs_dict.find(::tvm::tir::attr::kLinkedParams) != attrs_dict.end()) + << "no " << ::tvm::tir::attr::kLinkedParams << " attribute found!"; + + CHECK(linked_params.empty()) << "Multiple linked-param functions"; + linked_params = + Downcast>(attrs_dict[::tvm::tir::attr::kLinkedParams]); + continue; + } auto f = Downcast(kv.second); + funcs.emplace_back(f); + } + + cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); + for (const PrimFunc& f : funcs) { cg->AddFunction(f); } + if (!linked_params.empty()) { + cg->LinkParameters(linked_params); + } + // Uncomment to get the LLVM module right out of codegen, before optimizations. // std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n"; std::unique_ptr module = cg->Finish(); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index ab8e6eaad157..d719386d204b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -341,6 +341,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option("system-lib") + .add_attr_option("link-params", Bool(false)) .add_attr_option>("llvm-options") .set_default_keys({"hexagon"}); diff --git a/tests/python/unittest/test_target_codegen_hexagon.py b/tests/python/unittest/test_target_codegen_hexagon.py index 6ffb2f4741e8..28901b35e75b 100644 --- a/tests/python/unittest/test_target_codegen_hexagon.py +++ b/tests/python/unittest/test_target_codegen_hexagon.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import os import re import tvm +import tvm.relay import tvm.contrib.hexagon as hexagon @@ -107,7 +109,71 @@ def test_alloc_vtcm(): assert "HexagonBackendFreeVTCM" in calls +def test_linked_params_codegen(): + if not check_prereq_and_setup(): + return + + # A simple model (a single conv2d) to trigger parameter separation: + mod_lines = [ + '#[version = "0.0.5"]', + "def @main(%input: Tensor[(1, 16, 16, 3), uint8], %weights: Tensor[(3, 3, 3, 3), uint8])" + " -> Tensor[(1, 14, 14, 3), uint8] {", + ' nn.conv2d(%input, %weights, data_layout="NHWC", kernel_layout="HWIO", ' + 'kernel_size=[3, 3], out_dtype="uint8")', + "}", + ] + mod = tvm.parser.fromtext("\n".join(mod_lines)) + # Make the params be 81 x 'T': + params = {"weights": np.full([3, 3, 3, 3], fill_value=ord("T"), dtype=np.uint8)} + + target = tvm.target.hexagon("v68", link_params=True) + + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, target=target, target_host=target, params=params) + llvm_ir = lib.get_lib().get_source("ll") + + # The definition of the parameter: + p0_def_re = r"@__tvm_param__p0 = internal constant \[81 x i8\] c\"T{81}\", align 128" + assert re.search(p0_def_re, llvm_ir) + + # The body of the _lookup_linked_param function: + linked_param_re = r"(define.*@_lookup_linked_param\(.*\).* {[^}]*})" + linked_param_body = re.search(linked_param_re, llvm_ir, flags=re.MULTILINE) + assert linked_param_body and linked_param_body.groups() + + # Reference to the parameter: + p0_use_re = r"\[81 x i8\]\* @__tvm_param__p0" + assert re.search(p0_use_re, linked_param_body.groups()[0]) + + """ + A snippet of actual LLVM IR containing the definition of the linked + parameter, and the the body of the _lookup_linked_param function. + + + @__tvm_param__p0 = internal constant [81 x i8] c"TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT", align 128 + + define dllexport i32 @_lookup_linked_param(i8* nocapture readonly %0, i32* nocapture readnone %1, i32 %2, i8* nocapture %3, i32* nocapture %4, i8* nocapture readnone %5) local_unnamed_addr #2 { + entry: + %6 = bitcast i8* %0 to i64* + %7 = load i64, i64* %6, align 8 + %cond = icmp eq i64 %7, 1 + br i1 %cond, label %case___tvm_param__p0, label %common.ret + + common.ret: ; preds = %entry, %case___tvm_param__p0 + %storemerge = phi i32 [ 3, %case___tvm_param__p0 ], [ 4, %entry ] + store i32 %storemerge, i32* %4, align 4 + ret i32 0 + + case___tvm_param__p0: ; preds = %entry + %8 = bitcast i8* %3 to i8** + store i8* getelementptr inbounds ([81 x i8], [81 x i8]* @__tvm_param__p0, i32 0, i32 0), i8** %8, align 4 + br label %common.ret + } + """ + + if __name__ == "__main__": test_basic() test_llvm_target_features() test_alloc_vtcm() + test_linked_params_codegen()