From c8eace9baee9f714fae139f244b44666a08e7d7b Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 10 Sep 2021 16:02:10 -0700 Subject: [PATCH 01/35] Initial investigation --- python/tvm/driver/build_module.py | 39 ++++++++++++------- python/tvm/target/codegen.py | 2 + python/tvm/target/target.py | 2 +- src/ir/module.cc | 3 +- src/target/codegen.cc | 16 ++++++++ .../unittest/test_runtime_heterogeneous.py | 2 + 6 files changed, 48 insertions(+), 16 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a7ebc00c315f..b9c7cfe780d7 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -302,23 +302,26 @@ def build( ---- See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, schedule.Schedule): - if args is None: - raise ValueError("args must be given for build from schedule") + # TODO: After we move to C++, can we actually change the lower / build APIs? + + if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): input_mod = lower(inputs, args, name=name, binds=binds) - elif isinstance(inputs, (list, tuple, container.Array)): - merged_mod = tvm.IRModule({}) - for x in inputs: - merged_mod.update(lower(x)) - input_mod = merged_mod - elif isinstance(inputs, (tvm.IRModule, PrimFunc)): - input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( - f"Inputs must be Schedule, IRModule or dict of target to IRModule, " + f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " f"but got {type(inputs)}." ) - + + # More target maps here... is inputs ever a map? + # prepping and cutting module into chunks + + # 1. get into c++ + # 2. Remove everywhere that takes map + # after talking to xiyou he said a lot of difficulty was trying to maintain + # map correctly so I may just remove that. + if isinstance(inputs, (dict, container.Map)): + print("Inputs are: ", inputs) + assert False if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -332,23 +335,29 @@ def build( if not isinstance(mod, tvm.IRModule): raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + # This is for backwards compatibility but uses a map unfortunately + """ target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) - + """ + if not target_host: for tar, mod in target_input_mod.items(): tar = Target(tar) device_type = ndarray.device(tar.kind.name, 0).device_type + # This seems broken if device_type == ndarray.cpu(0).device_type: target_host = tar break if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - + + """ target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) + """ mod_host_all = tvm.IRModule({}) @@ -359,8 +368,10 @@ def build( device_modules.append(mdev) # Generate a unified host module. + print("codegen build module") rt_mod_host = codegen.build_module(mod_host_all, target_host) + # To start, push everything up to here into C++, then deal with the rest. Not sure what the rest is doing TBH. # Import all modules. 
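For reference while reading the input dispatch in build() above, a minimal sketch of the two call forms it has to support: the Schedule form and the dict-of-targets form this series wants to simplify (and which this patch temporarily short-circuits with an assert). The schedule below is purely illustrative:

    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # Schedule + args: build() lowers and compiles for a single target.
    rt_mod = tvm.build(s, [A, B], target="llvm", name="add_one")

    # Heterogeneous form: a dict mapping target -> already-lowered IRModule.
    lowered = tvm.lower(s, [A, B], name="add_one")
    rt_mod2 = tvm.build({"llvm": lowered}, target_host="llvm")
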
for mdev in device_modules: if mdev: diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 0ab4cb005cb4..e7549b3c0218 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -35,6 +35,8 @@ def build_module(mod, target): module : runtime.Module The corressponding module. """ + # this looks good! + print("In codegen build module") target = Target(target) if isinstance(target, str) else target return _ffi_api.Build(mod, target) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index af2f5d857293..403444419586 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -192,7 +192,7 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): if target is None: assert host is None, "Target host is not empty when target is empty." return target, host - if isinstance(target, dict) and "kind" not in target: + if isinstance(target, dict) and "kind" not in target: # Well this is awful. new_target = {} for tgt, mod in target.items(): if not target_is_dict_key: diff --git a/src/ir/module.cc b/src/ir/module.cc index 15c441d61a23..3fd8baab7f63 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -521,7 +521,8 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "IRModule(" << node->functions << ")"; + p->stream << "IRModule(" << node->functions << ")" + << "attrs = " << node->attrs; }); } // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 5a4aa39f01b4..fd465cbc5aec 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -40,6 +40,21 @@ namespace tvm { namespace codegen { +/* +runtime::Module BuildSchedule(Schedule sch, Target target) { + // call lower schedule to module + // call build?? +} + +runtime::Module BuildFunc(PrimFunc func, Target target) { + +} + +// Maybe don't need this here... +runtime::Module BuildTargetMap(Map, Target Target) { + +} +*/ runtime::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) @@ -311,6 +326,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, return (*codegen_f)(blob_byte_array, system_lib, target_triple); } +// Where build is registered TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export two auxiliary function to the runtime namespace. diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index 167f61d748c2..0e67083eb88b 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -399,6 +399,8 @@ def check_device(device, target_device): ) lower_add0.update(lower_add1) + # TODO: fix this + print("lower_add0 attrs: ", lower_add0.attr) target_flist = {target_device: lower_add0, target_host: lower_sub} target = tvm.target.Target(target, target_host) mhost = tvm.build(target_flist, target=target) From 8d6b2283e6bd753a3ac70e8a6078484823a395ab Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 14 Sep 2021 23:31:46 -0700 Subject: [PATCH 02/35] More progress! 
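Re the "where is Build defined" question that comes up later in this series: codegen.build_module goes through the FFI to the global registered in src/target/codegen.cc above as TVM_REGISTER_GLOBAL("target.Build"), which then dispatches on the target kind to a per-backend global (hence the target.build.<kind> hits when grepping). A small sketch of that hop, assuming an LLVM-enabled build:

    import tvm

    # The packed function behind tvm.target.codegen.build_module / _ffi_api.Build.
    build_f = tvm.get_global_func("target.Build")

    # The per-backend entry the C++ Build() forwards to for target kind "llvm".
    llvm_build = tvm.get_global_func("target.build.llvm", allow_missing=True)
    print(build_f, llvm_build)
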
--- python/tvm/driver/build_module.py | 27 +++++++++++-------- python/tvm/target/target.py | 4 ++- src/ir/module.cc | 4 +-- .../unittest/test_runtime_heterogeneous.py | 1 - 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index b9c7cfe780d7..939038c3be6f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -302,10 +302,16 @@ def build( ---- See the note on :any:`tvm.target` on target string format. """ - # TODO: After we move to C++, can we actually change the lower / build APIs? if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): input_mod = lower(inputs, args, name=name, binds=binds) + elif isinstance(inputs, (list, tuple, container.Array)): + merged_mod = tvm.IRModule({}) + for x in inputs: + merged_mod.update(lower(x)) + input_mod = merged_mod + elif isinstance(inputs, (tvm.IRModule, PrimFunc)): + input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " @@ -319,9 +325,7 @@ def build( # 2. Remove everywhere that takes map # after talking to xiyou he said a lot of difficulty was trying to maintain # map correctly so I may just remove that. - if isinstance(inputs, (dict, container.Map)): - print("Inputs are: ", inputs) - assert False + if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -329,19 +333,18 @@ def build( else: target_input_mod = inputs + # TODO: turn into a unified module + for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") - - # This is for backwards compatibility but uses a map unfortunately - """ + target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) - """ - + if not target_host: for tar, mod in target_input_mod.items(): tar = Target(tar) @@ -353,11 +356,13 @@ def build( if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - """ + target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) - """ + + # Turn into a map here + mod_host_all = tvm.IRModule({}) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 403444419586..9bac39418ead 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -174,6 +174,8 @@ def list_kinds(): """Returns the list of available target names.""" return list(_ffi_api.ListTargetKinds()) + + # TODO: make this return IRModule? @staticmethod def check_and_update_host_consist(target, host=None, target_is_dict_key=True): """A helper function that merges a legacy "target, target_host" pair, then returns @@ -192,7 +194,7 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): if target is None: assert host is None, "Target host is not empty when target is empty." return target, host - if isinstance(target, dict) and "kind" not in target: # Well this is awful. 
+ if isinstance(target, dict) and "kind" not in target: new_target = {} for tgt, mod in target.items(): if not target_is_dict_key: diff --git a/src/ir/module.cc b/src/ir/module.cc index 3fd8baab7f63..606991e59982 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -521,8 +521,8 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "IRModule(" << node->functions << ")" - << "attrs = " << node->attrs; + p->stream << "IRModule(" << node->functions << ")"; + // << "attrs = " << node->attrs; }); } // namespace tvm diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index 0e67083eb88b..c1d1267d5ea1 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -400,7 +400,6 @@ def check_device(device, target_device): lower_add0.update(lower_add1) # TODO: fix this - print("lower_add0 attrs: ", lower_add0.attr) target_flist = {target_device: lower_add0, target_host: lower_sub} target = tvm.target.Target(target, target_host) mhost = tvm.build(target_flist, target=target) From ba8836e0ac6dafb93cca4adf0b40206d4431c633 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Wed, 15 Sep 2021 13:13:17 -0700 Subject: [PATCH 03/35] More progress / notes --- python/tvm/driver/build_module.py | 29 +++++++++++++++++++---------- python/tvm/target/codegen.py | 2 +- python/tvm/target/target.py | 2 +- src/relay/backend/te_compiler.h | 17 +++++++++++++++++ 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 939038c3be6f..44889355d9a7 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -132,6 +132,7 @@ def lower( raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) +# TODO(@electriclilies): This should be moved into C++. def _build_for_device(input_mod, target, target_host): """Build the lowered functions for a device with the given compilation target. @@ -155,10 +156,14 @@ def _build_for_device(input_mod, target, target_host): mdev : tvm.module A module that contains device code. """ + # Ideally delete check_and_update_host_consist from here target, target_host = Target.check_and_update_host_consist(target, target_host) + # Point 1 device_type = ndarray.device(target.kind.name, 0).device_type mod_mixed = input_mod + # Do I need this? it is already assigned upstream supposedly.. + # get rid of it! delete as much as possible mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) opt_mixed = [ @@ -179,7 +184,7 @@ def _build_for_device(input_mod, target, target_host): tvm.tir.transform.SplitHostDevice(), ] mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed) - + # point 2 # device optimizations opt_device = tvm.transform.Sequential( [ @@ -220,7 +225,9 @@ def _build_for_device(input_mod, target, target_host): "Specified target %s, but cannot find device code, did you do " "bind?" % target ) + # rt_mod_dev is runtime::Module so this can be moved out maybe? rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None + # TIR module, runtime module return mod_host, rt_mod_dev @@ -303,7 +310,9 @@ def build( See the note on :any:`tvm.target` on target string format. 
""" + # Lowering if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): + # should this be te_lower instead? input_mod = lower(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) @@ -318,6 +327,7 @@ def build( f"but got {type(inputs)}." ) + # rest is codegen? # More target maps here... is inputs ever a map? # prepping and cutting module into chunks @@ -333,8 +343,6 @@ def build( else: target_input_mod = inputs - # TODO: turn into a unified module - for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") @@ -349,39 +357,40 @@ def build( for tar, mod in target_input_mod.items(): tar = Target(tar) device_type = ndarray.device(tar.kind.name, 0).device_type - # This seems broken if device_type == ndarray.cpu(0).device_type: target_host = tar break + # Why is this here? if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - + # why do we need to call chcek_and_update_host_consist again? target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) - # Turn into a map here - - mod_host_all = tvm.IRModule({}) + # This is building for target not device though.. + # From here through importing the device modules could probably be consolidated into one C++ function. device_modules = [] for tar, input_mod in target_input_mod.items(): + # mod_host is the module of the host.. bad name. + # Start with moving _build_for_device into c++ mod_host, mdev = _build_for_device(input_mod, tar, target_host) + # what are we updating here? mod_host_all.update(mod_host) device_modules.append(mdev) # Generate a unified host module. - print("codegen build module") rt_mod_host = codegen.build_module(mod_host_all, target_host) - # To start, push everything up to here into C++, then deal with the rest. Not sure what the rest is doing TBH. # Import all modules. for mdev in device_modules: if mdev: rt_mod_host.import_module(mdev) + # stop moving to C++ here. if not isinstance(target_host, Target): target_host = Target(target_host) if ( diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index e7549b3c0218..21d154a54279 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -35,8 +35,8 @@ def build_module(mod, target): module : runtime.Module The corressponding module. """ - # this looks good! print("In codegen build module") + # Where is Build defined? can't find it, can only find target.build.something... target = Target(target) if isinstance(target, str) else target return _ffi_api.Build(mod, target) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 9bac39418ead..e7e5e0277a8e 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -175,7 +175,7 @@ def list_kinds(): return list(_ffi_api.ListTargetKinds()) - # TODO: make this return IRModule? + # TODO: make this return IRModule? 
idk it seems @staticmethod def check_and_update_host_consist(target, host=None, target_is_dict_key=True): """A helper function that merges a legacy "target, target_host" pair, then returns diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 9d0eb1078ee0..c2c4dcafb353 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -201,6 +201,23 @@ IRModule LowerTE( transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn); + +/* +IRModule BuildForTargets(IRModule input_mod, Target target, Target target_host) { + +}*/ + +// TODO(@electriclilies): Rename me +// corresponds to point 1 thru point 2 in _build_for_device +IRModule build_for_device_mixed_mod(IRModule input_mod, Target target, Target target_host) { + IRModule mod_mixed = input_mod; + // TODO: put target as an attr on all the funcs + // mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) + + +} + + } // namespace tec } // namespace relay } // namespace tvm From b5bb9e8ca38e0fa6c65cd8b2f296a93fa8e2401e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Thu, 16 Sep 2021 21:19:19 -0700 Subject: [PATCH 04/35] rewrite build_for_device mostly in c++ --- python/tvm/driver/build_module.py | 79 ++++----------------- python/tvm/ir/function.py | 1 + python/tvm/relay/backend/compile_engine.py | 2 + src/driver/driver_api.cc | 80 ++++++++++++++++++++++ src/relay/backend/aot_executor_codegen.cc | 1 + src/relay/backend/compile_engine.cc | 1 + src/relay/backend/te_compiler.h | 15 +--- src/relay/backend/te_compiler_cache.cc | 3 + src/relay/backend/utils.h | 1 + src/target/codegen.cc | 2 + tests/cpp/relay_build_module_test.cc | 3 + 11 files changed, 108 insertions(+), 80 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 44889355d9a7..4409c4ae1b07 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -157,78 +157,26 @@ def _build_for_device(input_mod, target, target_host): A module that contains device code. """ # Ideally delete check_and_update_host_consist from here - target, target_host = Target.check_and_update_host_consist(target, target_host) + # target, target_host = Target.check_and_update_host_consist(target, target_host) # Point 1 - device_type = ndarray.device(target.kind.name, 0).device_type - - mod_mixed = input_mod - # Do I need this? it is already assigned upstream supposedly.. - # get rid of it! 
delete as much as possible - mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - - opt_mixed = [ - tvm.tir.transform.VerifyMemory(), - tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), - ] - if len(mod_mixed.functions) == 1: - opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] - - if PassContext.current().config.get("tir.detect_global_barrier", False): - opt_mixed += [tvm.tir.transform.ThreadSync("global")] - opt_mixed += [ - tvm.tir.transform.ThreadSync("shared"), - tvm.tir.transform.ThreadSync("warp"), - tvm.tir.transform.InferFragment(), - tvm.tir.transform.LowerThreadAllreduce(), - tvm.tir.transform.MakePackedAPI(), - tvm.tir.transform.SplitHostDevice(), - ] - mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed) - # point 2 - # device optimizations - opt_device = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" in f.attrs - and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.LowerWarpMemory(), - tvm.tir.transform.Simplify(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - ] - ) - mod_dev = opt_device(mod_mixed) - - # host optimizations - opt_host = tvm.transform.Sequential( - [ - tvm.tir.transform.Filter( - lambda f: "calling_conv" not in f.attrs - or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH - ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), - tvm.tir.transform.LowerTVMBuiltin(), - tvm.tir.transform.LowerDeviceStorageAccessInfo(), - tvm.tir.transform.LowerCustomDatatypes(), - tvm.tir.transform.LowerIntrin(), - tvm.tir.transform.CombineContextCall(), - ] - ) - mod_host = opt_host(mod_mixed) + from tvm.driver import _ffi_api as _driver_ffi + mod_mixed = _driver_ffi.get_mod_mixed(input_mod) + device_mod = _driver_ffi.get_device_mod(mod_mixed) + host_mod = _driver_ffi.get_device_mod(mod_mixed) + device_type = ndarray.device(target.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type and target_host == target: - assert len(mod_dev.functions) == 0 - if "gpu" in target.keys and len(mod_dev.functions) == 0: + assert len(device_mod.functions) == 0 + if "gpu" in target.keys and len(device_mod.functions) == 0: warnings.warn( "Specified target %s, but cannot find device code, did you do " "bind?" % target ) # rt_mod_dev is runtime::Module so this can be moved out maybe? - rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None - # TIR module, runtime module - return mod_host, rt_mod_dev + rt_mod_dev = codegen.build_module(device_mod, target) if len(device_mod.functions) != 0 else None + # TIR module for the host, runtime module for devices? 
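On the question in the comment above: yes — host_mod at this point is still a TIR IRModule (it only gets code-generated later by codegen.build_module(mod_host_all, target_host)), while rt_mod_dev is already a compiled runtime.Module for the device. The split itself is a Filter over the calling convention that SplitHostDevice stamps on the outlined kernels; a hedged Python sketch of that carving, mirroring the passes being deleted here:

    import tvm
    from tvm.ir import CallingConv

    def is_device_kernel(func):
        # SplitHostDevice marks the kernels it outlines with kDeviceKernelLaunch.
        return (
            "calling_conv" in func.attrs
            and func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
        )

    def split_by_calling_conv(mixed_mod):
        device_mod = tvm.tir.transform.Filter(is_device_kernel)(mixed_mod)
        host_mod = tvm.tir.transform.Filter(lambda f: not is_device_kernel(f))(mixed_mod)
        return host_mod, device_mod

With that in mind, the return below hands back the host-side TIR module and the built device module.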
+ return host_mod, rt_mod_dev + def build( @@ -319,8 +267,6 @@ def build( for x in inputs: merged_mod.update(lower(x)) input_mod = merged_mod - elif isinstance(inputs, (tvm.IRModule, PrimFunc)): - input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " @@ -414,6 +360,7 @@ def build( return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) +# What is OperatorModule and how is it different from runtime::Module class OperatorModule(Module): """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c3f1bf5f562a..c879935b5011 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -22,6 +22,7 @@ from . import _ffi_api +# Python CallingConv class CallingConv(IntEnum): """Possible kinds of calling conventions.""" diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index e9129db7b200..a2b7306eab3c 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -275,6 +275,8 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] +# Returns LoweredOutput + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bfea3e7b67c0..86af33b6d6e1 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -455,6 +455,8 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target // Can we make this take one annotated IRModule? // // Build for heterogeneous execution. +// +// It looks like this version of build doesn't lower, unlike the python version.... runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); @@ -530,12 +532,90 @@ runtime::Module build(const Map& inputs_arg, const Target& tar } // Build for homogeneous execution. +// Where is this called from? runtime::Module build(const IRModule& funcs, const Target& target_arg, const Target& target_host_arg) { auto target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); + // More maps of target and target host Map inputs = {{target, funcs}}; return build(inputs, target_host); } +// Gets the "mixed_module" from python driver/build_module.py's build function. +// Honestly not really sure what this actually is. +IRModule GetModMixed(IRModule mod) { + + transform::PassContext pass_ctx = transform::PassContext::Current(); + + Array pass_list; + pass_list.push_back(tir::transform::VerifyMemory()); + pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + + // Python annotates all functions in the mod with the Target passed in here; I think we shouldn't have to do that. 
+ + bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); + if (detect_global_barrier) { + pass_list.push_back(tir::transform::ThreadSync("global")); + } + + pass_list.push_back(tir::transform::ThreadSync("shared")); + pass_list.push_back(tir::transform::ThreadSync("warp")); + pass_list.push_back(tir::transform::InferFragment()); + pass_list.push_back(tir::transform::LowerThreadAllreduce()); + pass_list.push_back(tir::transform::MakePackedAPI(-1)); // -1 is the default input passed in in the python version + pass_list.push_back(tir::transform::SplitHostDevice()); + + + return transform::Sequential(pass_list)(mod); + +} +TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod) { + return GetModMixed(mod); +}); + +IRModule GetDeviceMod(IRModule mixed_mod) { + Array pass_list; + auto check_calling_conv_func = [=] (tir::PrimFunc func) { + Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); + return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); + + }; + + pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention + pass_list.push_back(tir::transform::LowerWarpMemory()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + pass_list.push_back(tir::transform::LowerCustomDatatypes()); + pass_list.push_back(tir::transform::LowerIntrin()); + + return transform::Sequential(pass_list)(mixed_mod); +} + +TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod) { + return GetDeviceMod(mod); +}); + +IRModule GetHostMod(IRModule mixed_mod) { + Array pass_list; + auto check_calling_conv_func = [=] (tir::PrimFunc func) { + Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); + return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); + + }; + + pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention + // Python version added target_host as an attribute to every function here + pass_list.push_back(tir::transform::LowerTVMBuiltin()); + pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + pass_list.push_back(tir::transform::LowerCustomDatatypes()); + pass_list.push_back(tir::transform::LowerIntrin()); + pass_list.push_back(tir::transform::CombineContextCall()); + + return transform::Sequential(pass_list)(mixed_mod); +} +TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod) { + return GetHostMod(mod); +}); + } // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 70779ac58abf..f99e34a31e69 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -557,6 +557,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { target_host_(target_host), use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} + // Returns LoweredOutput LoweredOutput Codegen(relay::Function func, String mod_name) { auto aot_allocator = AOTOnDemandAllocator(); aot_allocator.Run(func); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6142e8323dea..04471bfa4d02 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -282,6 +282,7 @@ CompileEngine& CompileEngine::Global() { TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); 
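To poke at the three driver.* globals registered earlier in this patch from Python, something like the following should work. This is only a sketch: it assumes the target attribute has already been bound onto every PrimFunc (which is what the deleted Python pipeline did with tvm.tir.transform.Apply), and whether the unbound case behaves at this point in the series is an open question:

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B])

    # Bind the target onto each function first, as the old Python code did.
    tgt = tvm.target.Target("llvm")
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tgt))(mod)

    get_mixed = tvm.get_global_func("driver.get_mod_mixed")
    get_host = tvm.get_global_func("driver.get_host_mod")
    get_device = tvm.get_global_func("driver.get_device_mod")

    mixed = get_mixed(mod)
    host_mod, device_mod = get_host(mixed), get_device(mixed)
    print(len(host_mod.functions), len(device_mod.functions))
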
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); +// Make LoweredOutput TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") .set_body_typed([](tvm::Array outputs, OpImplementation impl) { return LoweredOutput(outputs, impl); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index c2c4dcafb353..09fa18802f03 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -37,6 +37,7 @@ #include #include +#include #include #include #include @@ -202,20 +203,6 @@ transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn); -/* -IRModule BuildForTargets(IRModule input_mod, Target target, Target target_host) { - -}*/ - -// TODO(@electriclilies): Rename me -// corresponds to point 1 thru point 2 in _build_for_device -IRModule build_for_device_mixed_mod(IRModule input_mod, Target target, Target target_host) { - IRModule mod_mixed = input_mod; - // TODO: put target as an attr on all the funcs - // mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - - -} } // namespace tec diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d0e83765928a..d34f01d67d0f 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -54,6 +54,7 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode); TVM_REGISTER_NODE_TYPE(CCacheKeyNode); TVM_REGISTER_NODE_TYPE(CCacheValueNode); +// LoweredOutput constructor! LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { auto n = make_object(); n->outputs = std::move(outputs); @@ -224,6 +225,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array VisitExpr_(const CallNode* call_node) final { static auto fpattern = Op::GetAttrMap("TOpPattern"); + // So this is the PYTHON version not the C++ version defined in relay_build_module_test.cc static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); ICHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -254,6 +256,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator const auto* copy_input = inputs[0].operator->(); outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { + // Right so we need to change lower_call LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; impl = lowered_out->implementation; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index cf8a2dd4b8e0..d07ae94bc37a 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -133,6 +133,7 @@ class FunctionInfo : public ObjectRef { */ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); +// LoweredOutput definition -- why is this redefined in python??? /*! * \brief Executor generator artifacts. Those artifacts are subsequently * used by the relay build process. diff --git a/src/target/codegen.cc b/src/target/codegen.cc index fd465cbc5aec..ff019691d252 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -55,6 +55,8 @@ runtime::Module BuildTargetMap(Map, Target Target) { } */ +// Leave this -- why not called codegen? 
+// TODO(@electriclilies): Rename this to Codegen (dont get consensus just try to put it in) runtime::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index ebb2867e7b69..8aa603d1a818 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -66,6 +66,9 @@ TVM_REGISTER_GLOBAL("relay.backend.lower_call") OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); auto impl = strategy->specializations[0]->implementations[0]; auto outs = impl.Compute(call->attrs, inputs, out_type); + // Using make_LoweredOutput here + // wait ok is this the python LoweredOutput or the C++ LoweredOutput? + // This is a test. so its probably ok auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); if (!f) { LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; From f9372db99fb309e4e48bd07690e943f5e65d77d1 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 17 Sep 2021 08:43:35 -0700 Subject: [PATCH 05/35] More progress --- include/tvm/runtime/module.h | 2 ++ python/tvm/relay/build_module.py | 2 ++ src/driver/driver_api.cc | 1 + src/relay/backend/aot_executor_codegen.cc | 1 + src/relay/backend/build_module.cc | 4 ++++ src/relay/backend/te_compiler_cache.cc | 1 + src/runtime/module.cc | 2 ++ 7 files changed, 13 insertions(+) diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 71be8d218d2d..8560f7b399c7 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -134,6 +134,7 @@ class TVM_DLL ModuleNode : public Object { * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. */ + // This is pure virtual which means its only instantiated by subclasses of ModuleNode. virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) = 0; /*! @@ -240,6 +241,7 @@ constexpr const char* tvm_entrypoint_suffix = "run"; inline void Module::Import(Module other) { return (*this)->Import(other); } +// seems questionable to provide a mutable pointer into the runtime module? inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c67ac1dc423d..fb9a8b78127f 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -93,6 +93,8 @@ class BuildModule(object): """ def __init__(self): + # This is implicitly calling GetFunction in RelayBuildModuleNode which then calls the correct + # executor's version of that function. 
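On the GetFunction chase noted above: indexing a runtime.Module by name is just sugar for get_function(), which lands in the GetFunction override of the concrete ModuleNode subclass — RelayBuildModuleNode here — which is why mod["build"] resolves to the lambda registered under "build". A hedged sketch (assuming the usual relay.build_module._BuildModule FFI registration):

    import tvm

    bm = tvm.get_global_func("relay.build_module._BuildModule")()
    build_packed = bm["build"]             # -> RelayBuildModuleNode::GetFunction("build")
    same_packed = bm.get_function("build") # equivalent spelling

The lines that follow wire exactly this up for get_graph_json, get_module, build, and friends.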
self.mod = _build_module._BuildModule() self._get_graph_json = self.mod["get_graph_json"] self._get_module = self.mod["get_module"] diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 86af33b6d6e1..0fa7aacfe817 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -465,6 +465,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar Target target_host = target_host_arg; // Fetch previous defined target host in targets + // this is redefined in python ahh CheckAndUpdateHostConsistency(&inputs, &target_host); if (!target_host.defined()) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index f99e34a31e69..e28ebb7ce217 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -730,6 +730,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { *rv = get_param_id(key); }); } else if (name == "get_irmodule") { + // OK here is one get_irmodule! return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); } else if (name == "get_external_modules") { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 69dced36295e..2efbf2bdef4c 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -176,6 +176,7 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { + // OH this must be where the self.build = mod["build"] comes from! return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 5); this->Build(args[0], args[1], args[2], args[3], args[4]); @@ -194,6 +195,7 @@ class RelayBuildModule : public runtime::ModuleNode { } }); } else if (name == "get_irmodule") { + // GetIRModule just calls the GetFunction of the executor, ends up in a DIFFERENT ModuleNode's GetFunction (oof) return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->executor_codegen_->GetIRModule(); }); @@ -488,9 +490,11 @@ class RelayBuildModule : public runtime::ModuleNode { executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); + // Another Map auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. + // TODO(Ext_dev shouldn't be passed in in this module I think so eventually so we can nuke it) Target ext_dev("ext_dev"); if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { lowered_funcs.Set(ext_dev, IRModule()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d34f01d67d0f..53803b7cf9e2 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -443,6 +443,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; + // Huh why are we lowering the schedule here?? Seems weird. 
IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, diff --git a/src/runtime/module.cc b/src/runtime/module.cc index f9c281ab9d02..31b65ade6f43 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -63,10 +63,12 @@ void ModuleNode::Import(Module other) { PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) { ModuleNode* self = this; + // This must be the REAL GetFunction but IDK where it actually lives!! PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); if (pf != nullptr) return pf; if (query_imports) { for (Module& m : self->imports_) { + // Where is this defined? pf = m.operator->()->GetFunction(name, query_imports); if (pf != nullptr) { return pf; From 322f9f1eec2f77195ef9d76ad05ac25286fcb860 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Mon, 20 Sep 2021 13:56:44 +0300 Subject: [PATCH 06/35] Initial split of transformations applied to device and host as post split action from mixed module --- include/tvm/driver/driver_api.h | 16 +++++ src/driver/driver_api.cc | 114 +++++++++++++++++--------------- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 418d532fdd5f..00b624d8f65d 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -44,6 +44,22 @@ namespace tvm { +/*! + * \brief Returns the optimized IRModule for the device Target after device/host from mixed module. + * \param mixed_mod The optimized mixed module. + * \param target The device Target. + * \return The result optimized device module. + */ +IRModule GetDeviceOptimizedModule(IRModule mixed_mod, Target target); + +/*! + * \brief Returns the optimized IRModule for the host Target after device/host from mixed module. + * \param mixed_mod The optimized mixed module. + * \param target The host Target. + * \return The result optimized host module. + */ +IRModule GetOptimizedHostModule(IRModule mixed_mod, Target target_host); + /*! * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) * \param mod The IRmodule to lower diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0fa7aacfe817..369759fb3dfc 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -373,6 +373,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); +// Splits module into one to run on the device and one to run the host. 
E.g., CUDA, OpenCL etc std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg, const transform::PassContext& pass_ctx) { @@ -381,6 +382,8 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target Array mixed_pass_list = {BindTarget(target), tir::transform::VerifyMemory()}; + printf("calling split for device ********** \n"); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); @@ -401,61 +404,27 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto opt_mixed = transform::Sequential(mixed_pass_list); mod_mixed = opt_mixed(std::move(mod_mixed)); - auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), - }; - auto opt_host = transform::Sequential(host_pass_list); ICHECK(mod_mixed.defined()) << "This module must be defined"; - auto mhost = opt_host(mod_mixed); - // device pipeline - auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - }; - auto opt_device = transform::Sequential(device_pass_list); - auto mdevice = opt_device(mod_mixed); + auto host_mod = GetOptimizedHostModule(mod_mixed, target_host); + + auto device_mod = GetDeviceOptimizedModule(mod_mixed, target); // some final misc checks. auto keys = target->GetKeys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && mdevice->functions.size() == 0) { + if (target_is_gpu && device_mod->functions.size() == 0) { LOG(WARNING) << "Specified target " << target->str() << " but cannot find device code. Did you forget to bind?"; } - if (target->kind->device_type == kDLCPU && target_host == target) { - // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. - // We need to relax this check for just TIR functions. - // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - // << "and host_target are both llvm target." - // << "\n"; - } - - return {mhost, mdevice}; + return {host_mod, device_mod}; } // Can we make this take one annotated IRModule? // // Build for heterogeneous execution. -// +// // It looks like this version of build doesn't lower, unlike the python version.... runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); @@ -533,7 +502,8 @@ runtime::Module build(const Map& inputs_arg, const Target& tar } // Build for homogeneous execution. -// Where is this called from? +// Where is this called from?] 
+// called from compile engine and it accepts lowered functions runtime::Module build(const IRModule& funcs, const Target& target_arg, const Target& target_host_arg) { auto target = target_arg, target_host = target_host_arg; @@ -546,16 +516,17 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, // Gets the "mixed_module" from python driver/build_module.py's build function. // Honestly not really sure what this actually is. IRModule GetModMixed(IRModule mod) { - transform::PassContext pass_ctx = transform::PassContext::Current(); Array pass_list; pass_list.push_back(tir::transform::VerifyMemory()); pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - // Python annotates all functions in the mod with the Target passed in here; I think we shouldn't have to do that. + // Python annotates all functions in the mod with the Target passed in here; I think we shouldn't + // have to do that. - bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); + bool detect_global_barrier = + pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { pass_list.push_back(tir::transform::ThreadSync("global")); } @@ -564,26 +535,25 @@ IRModule GetModMixed(IRModule mod) { pass_list.push_back(tir::transform::ThreadSync("warp")); pass_list.push_back(tir::transform::InferFragment()); pass_list.push_back(tir::transform::LowerThreadAllreduce()); - pass_list.push_back(tir::transform::MakePackedAPI(-1)); // -1 is the default input passed in in the python version + pass_list.push_back(tir::transform::MakePackedAPI( + -1)); // -1 is the default input passed in in the python version pass_list.push_back(tir::transform::SplitHostDevice()); - return transform::Sequential(pass_list)(mod); - } + TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod) { return GetModMixed(mod); }); IRModule GetDeviceMod(IRModule mixed_mod) { Array pass_list; - auto check_calling_conv_func = [=] (tir::PrimFunc func) { + auto check_calling_conv_func = [=](tir::PrimFunc func) { Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); - }; - pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention + pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention pass_list.push_back(tir::transform::LowerWarpMemory()); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); @@ -599,13 +569,12 @@ TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod) { IRModule GetHostMod(IRModule mixed_mod) { Array pass_list; - auto check_calling_conv_func = [=] (tir::PrimFunc func) { + auto check_calling_conv_func = [=](tir::PrimFunc func) { Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); - }; - pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention + pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention // Python version added target_host as an attribute to every function here pass_list.push_back(tir::transform::LowerTVMBuiltin()); pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); @@ -619,4 +588,43 @@ TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod) { return GetHostMod(mod); }); +IRModule 
GetOptimizedHostModule(IRModule mixed_mod, Target target_host) { + auto host_pass_list = { + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target_host), + tir::transform::LowerTVMBuiltin(), + tir::transform::LowerCustomDatatypes(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + tir::transform::CombineContextCall(), + }; + auto host_module = transform::Sequential(host_pass_list); + ICHECK(mixed_mod.defined()) << "This module must be defined"; + + return host_module(mixed_mod); +} + +IRModule GetDeviceOptimizedModule(IRModule mixed_mod, Target target) { + // device pipeline + auto device_pass_list = { + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target), + tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), + tir::transform::LowerCustomDatatypes(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + }; + auto device_opt_mod = transform::Sequential(device_pass_list); + auto mdevice = device_opt_mod(mixed_mod); + + return device_opt_mod(mixed_mod); +} + } // namespace tvm From 0c4a01d53d9e635905605518616579dbea507ec9 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Mon, 20 Sep 2021 15:04:43 +0300 Subject: [PATCH 07/35] Combine duplicate passes after spliting mod on aot and vm flows --- include/tvm/driver/driver_api.h | 11 ++- src/driver/driver_api.cc | 170 +++++++++++++------------------- 2 files changed, 76 insertions(+), 105 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 00b624d8f65d..1ac3e703695e 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -44,13 +44,20 @@ namespace tvm { +/*! + * \brief Returns the optimized IRModule for original fused module (pre split) that contains device + * and host code. \param mixed_mod The original mixed module. \param target The device Target. + * \return The result optimized mixed module. + */ +IRModule OptimizeMixedModule(IRModule mixed_mod, Target target); + /*! * \brief Returns the optimized IRModule for the device Target after device/host from mixed module. * \param mixed_mod The optimized mixed module. * \param target The device Target. * \return The result optimized device module. */ -IRModule GetDeviceOptimizedModule(IRModule mixed_mod, Target target); +IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target); /*! * \brief Returns the optimized IRModule for the host Target after device/host from mixed module. @@ -58,7 +65,7 @@ IRModule GetDeviceOptimizedModule(IRModule mixed_mod, Target target); * \param target The host Target. * \return The result optimized host module. */ -IRModule GetOptimizedHostModule(IRModule mixed_mod, Target target_host); +IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host); /*! 
* \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 369759fb3dfc..991c66496136 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -379,36 +379,14 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target const transform::PassContext& pass_ctx) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); - Array mixed_pass_list = {BindTarget(target), - tir::transform::VerifyMemory()}; - - printf("calling split for device ********** \n"); - - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { - mixed_pass_list.push_back(tir::transform::ThreadSync("global")); - } - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); - mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); - mixed_pass_list.push_back(tir::transform::InferFragment()); - mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - - if (target->GetAttr("unpacked-api").value_or(Bool(false))) { - mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); - } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); - } - - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - - auto opt_mixed = transform::Sequential(mixed_pass_list); - mod_mixed = opt_mixed(std::move(mod_mixed)); ICHECK(mod_mixed.defined()) << "This module must be defined"; - auto host_mod = GetOptimizedHostModule(mod_mixed, target_host); + mod_mixed = OptimizeMixedModule(mod_mixed, target); - auto device_mod = GetDeviceOptimizedModule(mod_mixed, target); + auto host_mod = OptimizeHostModule(mod_mixed, target_host); + + auto device_mod = OptimizeDeviceModule(mod_mixed, target); // some final misc checks. auto keys = target->GetKeys(); @@ -515,12 +493,16 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, // Gets the "mixed_module" from python driver/build_module.py's build function. // Honestly not really sure what this actually is. -IRModule GetModMixed(IRModule mod) { +IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - Array pass_list; - pass_list.push_back(tir::transform::VerifyMemory()); - pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + Array mixed_pass_list; + if (target.defined()) { + mixed_pass_list.push_back(BindTarget(target)); + } + + mixed_pass_list.push_back(tir::transform::VerifyMemory()); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); // Python annotates all functions in the mod with the Target passed in here; I think we shouldn't // have to do that. 
@@ -528,100 +510,82 @@ IRModule GetModMixed(IRModule mod) { bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { - pass_list.push_back(tir::transform::ThreadSync("global")); + mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } - pass_list.push_back(tir::transform::ThreadSync("shared")); - pass_list.push_back(tir::transform::ThreadSync("warp")); - pass_list.push_back(tir::transform::InferFragment()); - pass_list.push_back(tir::transform::LowerThreadAllreduce()); - pass_list.push_back(tir::transform::MakePackedAPI( - -1)); // -1 is the default input passed in in the python version - pass_list.push_back(tir::transform::SplitHostDevice()); + mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); + mixed_pass_list.push_back(tir::transform::InferFragment()); + mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - return transform::Sequential(pass_list)(mod); + if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); + } else { + mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); + } + mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + + auto opt_mixed = transform::Sequential(mixed_pass_list); + return opt_mixed(std::move(mixed_mod)); } TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod) { - return GetModMixed(mod); + Target empty_target; + return OptimizeMixedModule(mod, empty_target); }); -IRModule GetDeviceMod(IRModule mixed_mod) { - Array pass_list; - auto check_calling_conv_func = [=](tir::PrimFunc func) { - Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); - return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); - }; - - pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention - pass_list.push_back(tir::transform::LowerWarpMemory()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - pass_list.push_back(tir::transform::LowerCustomDatatypes()); - pass_list.push_back(tir::transform::LowerIntrin()); - - return transform::Sequential(pass_list)(mixed_mod); -} - TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod) { - return GetDeviceMod(mod); + Target empty_target; + return OptimizeDeviceModule(mod, empty_target); }); -IRModule GetHostMod(IRModule mixed_mod) { - Array pass_list; - auto check_calling_conv_func = [=](tir::PrimFunc func) { - Optional calling_conv = func->GetAttr(tvm::attr::kCallingConv); - return (!calling_conv) || (calling_conv.value() != CallingConv::kDeviceKernelLaunch); - }; - - pass_list.push_back(Filter(check_calling_conv_func)); // Filter by calling convention - // Python version added target_host as an attribute to every function here - pass_list.push_back(tir::transform::LowerTVMBuiltin()); - pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - pass_list.push_back(tir::transform::LowerCustomDatatypes()); - pass_list.push_back(tir::transform::LowerIntrin()); - pass_list.push_back(tir::transform::CombineContextCall()); - - return transform::Sequential(pass_list)(mixed_mod); -} TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod) { - return GetHostMod(mod); + Target empty_target; + return OptimizeHostModule(mod, empty_target); }); -IRModule GetOptimizedHostModule(IRModule 
mixed_mod, Target target_host) { - auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), - }; +IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host) { + Array host_pass_list; + host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + })); + + if (target_host.defined()) { + host_pass_list.push_back(BindTarget(target_host)); + } + + host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + host_pass_list.push_back(tir::transform::LowerIntrin()); + host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + host_pass_list.push_back(tir::transform::CombineContextCall()); + auto host_module = transform::Sequential(host_pass_list); ICHECK(mixed_mod.defined()) << "This module must be defined"; return host_module(mixed_mod); } -IRModule GetDeviceOptimizedModule(IRModule mixed_mod, Target target) { - // device pipeline - auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerCustomDatatypes(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - }; +IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target) { + Array device_pass_list; + device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + })); + + if (target.defined()) { + device_pass_list.push_back(BindTarget(target)); + } + + device_pass_list.push_back(tir::transform::LowerWarpMemory()); + device_pass_list.push_back(tir::transform::Simplify()); + device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + device_pass_list.push_back(tir::transform::LowerIntrin()); + device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + auto device_opt_mod = transform::Sequential(device_pass_list); + auto mdevice = device_opt_mod(mixed_mod); return device_opt_mod(mixed_mod); From 73640e84696876ad276b5ceeff56b0f2a9498986 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Mon, 20 Sep 2021 15:15:58 +0300 Subject: [PATCH 08/35] Minor cleanup --- src/driver/driver_api.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 991c66496136..a4029c0826aa 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -374,9 +374,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") }); // Splits module into one to run on the device and one to run the host. 
E.g., CUDA, OpenCL etc -std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg, - const transform::PassContext& pass_ctx) { +std::pair SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg, + const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); @@ -437,7 +436,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx); + auto pair = SplitFuncsToDevHostMods(it.second, it.first, target_host); auto& mhost = pair.first; auto& mdevice = pair.second; From d0ba8b8c8eda38025eda6bf0793cadefcc754a42 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 22 Sep 2021 15:53:46 +0300 Subject: [PATCH 09/35] Move target mangling to driver_api.cc --- python/tvm/driver/build_module.py | 103 ++++++++++++++++++------------ src/driver/driver_api.cc | 78 +++++++++++++++++----- 2 files changed, 122 insertions(+), 59 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4409c4ae1b07..acc82c784725 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,6 +87,8 @@ def schedule_to_module( """ return ffi.schedule_to_module(sch, args, name, binds) +# cpp duck typing ? + def lower( inp: Union[schedule.Schedule, PrimFunc, IRModule], @@ -123,13 +125,15 @@ def lower( m : IRModule The result IRModule """ + # ffi.relay.lower_te_pass() if isinstance(inp, IRModule): return ffi.lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) # TODO(@electriclilies): This should be moved into C++. @@ -162,7 +166,7 @@ def _build_for_device(input_mod, target, target_host): from tvm.driver import _ffi_api as _driver_ffi mod_mixed = _driver_ffi.get_mod_mixed(input_mod) device_mod = _driver_ffi.get_device_mod(mod_mixed) - host_mod = _driver_ffi.get_device_mod(mod_mixed) + host_mod = _driver_ffi.get_host_mod(mod_mixed) device_type = ndarray.device(target.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type and target_host == target: @@ -173,12 +177,12 @@ def _build_for_device(input_mod, target, target_host): ) # rt_mod_dev is runtime::Module so this can be moved out maybe? - rt_mod_dev = codegen.build_module(device_mod, target) if len(device_mod.functions) != 0 else None + rt_mod_dev = codegen.build_module(device_mod, target) if len( + device_mod.functions) != 0 else None # TIR module for the host, runtime module for devices? 
return host_mod, rt_mod_dev - def build( inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -192,7 +196,8 @@ def build( Parameters ---------- - inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] + inputs : Union[tvm.te.schedule.Schedule, + tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] @@ -257,7 +262,7 @@ def build( ---- See the note on :any:`tvm.target` on target string format. """ - + # Lowering if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): # should this be te_lower instead? @@ -272,7 +277,9 @@ def build( f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " f"but got {type(inputs)}." ) - + + # move to cpp from this point + # rest is codegen? # More target maps here... is inputs ever a map? # prepping and cutting module into chunks @@ -281,39 +288,49 @@ def build( # 2. Remove everywhere that takes map # after talking to xiyou he said a lot of difficulty was trying to maintain # map correctly so I may just remove that. - - if not isinstance(inputs, (dict, container.Map)): - target = Target.current() if target is None else target - target = target if target else "llvm" - target_input_mod = {target: input_mod} - else: - target_input_mod = inputs - - for tar, mod in target_input_mod.items(): - if not isinstance(tar, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") - - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) - - if not target_host: - for tar, mod in target_input_mod.items(): - tar = Target(tar) - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: - target_host = tar - break - # Why is this here? - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - - # why do we need to call chcek_and_update_host_consist again? - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) + + # if not isinstance(inputs, (dict, container.Map)): + # target = Target.current() if target is None else target + # target = target if target else "llvm" + # target_input_mod = {target: input_mod} + # else: + # target_input_mod = inputs + + # # reshape tensor tests are failing without this one vm_reshape_tensor + + # for tar, mod in target_input_mod.items(): + # if not isinstance(tar, (str, Target)): + # raise ValueError( + # "The key of inputs must be str or " "Target when inputs is dict.") + # if not isinstance(mod, tvm.IRModule): + # raise ValueError( + # "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + + # target_input_mod, target_host = Target.check_and_update_host_consist( + # target_input_mod, target_host + # ) + + # if not target_host: + # for tar, mod in target_input_mod.items(): + # tar = Target(tar) + # device_type = ndarray.device(tar.kind.name, 0).device_type + # if device_type == ndarray.cpu(0).device_type: + # target_host = tar + # break + # # Why is this here? + # # if not target_host: + # # target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + + # # why do we need to call chcek_and_update_host_consist again? 
+ # target_input_mod, target_host = Target.check_and_update_host_consist( + # target_input_mod, target_host + # ) + from tvm.driver import _ffi_api as _driver_ffi + + target_input_mod = _driver_ffi.driver.target_mangling( + input_mod, target, target_host) + target_host = _driver_ffi.driver.host_target_mangling( + input_mod, target, target_host) mod_host_all = tvm.IRModule({}) @@ -347,13 +364,15 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module( + [rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module( + [rt_mod_host], target_host) else: to_return = rt_mod_host diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index a4029c0826aa..d56064c0252d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -398,6 +398,53 @@ std::pair SplitFuncsToDevHostMods(IRModule mod_mixed, const return {host_mod, device_mod}; } +std::pair TargetTypeMangling(const Map& inputs_arg, Target target, + Target target_host_arg) { + Target target_input_mod, target_host; + + target = !target.defined() ? target.Current() : target; + + std::vector device_modules; + Map inputs = inputs_arg; + target_host = target_host_arg; + + CheckAndUpdateHostConsistency(&inputs, &target_host); + + if (!target_host.defined()) { + for (const auto& it : inputs) { + if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { + target_host = it.first; + break; + } + } + } + + if (!target_host.defined()) { + target_host = DefaultTargetHost(target_host); + } + CheckAndUpdateHostConsistency(&inputs, &target_host); + + return {target_input_mod, target_host}; +} + +// TVM_REGISTER_GLOBAL("driver.target_mangling") +// .set_body_typed([](const Map& inputs_arg, IRModule mod, Target target, +// Target target_host_arg) { +// return TargetTypeMangling(inputs_arg, mod, target, target_host_arg); +// }); + +TVM_REGISTER_GLOBAL("driver.target_mangling") + .set_body_typed([](const Map& inputs_arg, Target target, + Target target_host_arg) { + return TargetTypeMangling(inputs_arg, target, target_host_arg).first; + }); + +TVM_REGISTER_GLOBAL("driver.host_target_mangling") + .set_body_typed([](const Map& inputs_arg, Target target, + Target target_host_arg) { + return TargetTypeMangling(inputs_arg, target, target_host_arg).second; + }); + // Can we make this take one annotated IRModule? // // Build for heterogeneous execution. 
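For readability, here is a rough Python rendering of the host-target selection that the TargetTypeMangling helper above encodes. It is only a sketch of the intended behaviour and is not part of the patch: pick_target_host and the sample target map are invented names, and the CPU check simply reuses the ndarray.device(...).device_type comparison that the Python build() already performs.

import tvm
from tvm.runtime import ndarray


def pick_target_host(target_input_mod, target_host=None):
    """Sketch of the host-target fallback; target_input_mod maps Target -> IRModule."""
    if target_host is None:
        for tgt in target_input_mod:
            tgt = tvm.target.Target(tgt)  # accepts str or Target keys, as in build()
            # kDLCPU / kDLMicroDev on the C++ side; the Python side only checks CPU
            if ndarray.device(tgt.kind.name, 0).device_type == ndarray.cpu(0).device_type:
                target_host = tgt
                break
    if target_host is None:
        # Same default the Python build() currently falls back to
        target_host = tvm.target.Target("llvm" if tvm.runtime.enabled("llvm") else "stackvm")
    return target_host


# With only a device target in the map, the fallback host gets picked.
print(pick_target_host({tvm.target.Target("cuda"): tvm.IRModule({})}))
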
@@ -437,17 +484,17 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { auto pair = SplitFuncsToDevHostMods(it.second, it.first, target_host); - auto& mhost = pair.first; - auto& mdevice = pair.second; + auto& host_mod = pair.first; + auto& device_mod = pair.second; - ICHECK(mhost.defined()) << "The split host module must be defined"; + ICHECK(host_mod.defined()) << "The split host module must be defined"; ICHECK(mhost_all.defined()) << "The host module must be defined"; - mhost_all->Update(mhost); + mhost_all->Update(host_mod); - if (mdevice->functions.size() != 0) { - device_modules.push_back(codegen::Build(mdevice, it.first)); + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); } } } @@ -503,8 +550,8 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - // Python annotates all functions in the mod with the Target passed in here; I think we shouldn't - // have to do that. + // Python annotates all functions in the mod with the Target passed in here; I think we + // shouldn't have to do that. bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); @@ -528,19 +575,16 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { return opt_mixed(std::move(mixed_mod)); } -TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod) { - Target empty_target; - return OptimizeMixedModule(mod, empty_target); +TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod, Target target) { + return OptimizeMixedModule(mod, target); }); -TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod) { - Target empty_target; - return OptimizeDeviceModule(mod, empty_target); +TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod, Target target) { + return OptimizeDeviceModule(mod, target); }); -TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod) { - Target empty_target; - return OptimizeHostModule(mod, empty_target); +TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod, Target target_host) { + return OptimizeHostModule(mod, target_host); }); IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host) { From 01b4ce3ca85b9ca5fdf2e0bb0353b3914d562a12 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 22 Sep 2021 18:40:02 +0300 Subject: [PATCH 10/35] Move more build utlities to cpp driver api --- python/tvm/driver/build_module.py | 43 +++++++++++++++-------------- src/driver/driver_api.cc | 45 ++++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 26 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index acc82c784725..2756e6fd7ac8 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -332,26 +332,29 @@ def build( target_host = _driver_ffi.driver.host_target_mangling( input_mod, target, target_host) - mod_host_all = tvm.IRModule({}) - - # This is building for target not device though.. - # From here through importing the device modules could probably be consolidated into one C++ function. - device_modules = [] - for tar, input_mod in target_input_mod.items(): - # mod_host is the module of the host.. bad name. 
- # Start with moving _build_for_device into c++ - mod_host, mdev = _build_for_device(input_mod, tar, target_host) - # what are we updating here? - mod_host_all.update(mod_host) - device_modules.append(mdev) - - # Generate a unified host module. - rt_mod_host = codegen.build_module(mod_host_all, target_host) - - # Import all modules. - for mdev in device_modules: - if mdev: - rt_mod_host.import_module(mdev) + # mod_host_all = tvm.IRModule({}) + + # # This is building for target not device though.. + # # From here through importing + # # the device modules could probably be consolidated into one C++ function. + # device_modules = [] + # for tar, input_mod in target_input_mod.items(): + # # mod_host is the module of the host.. bad name. + # mod_host, mdev = _build_for_device(input_mod, tar, target_host) + # # what are we updating here? + # mod_host_all.update(mod_host) + # device_modules.append(mdev) + + # # Generate a unified host module. + # rt_mod_host = codegen.build_module(mod_host_all, target_host) + + # # Import all modules. + # for mdev in device_modules: + # if mdev: + # rt_mod_host.import_module(mdev) + + rt_mod_host = _driver_ffi.driver.finalize_module( + target_input_mod, target_host) # stop moving to C++ here. if not isinstance(target_host, Target): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d56064c0252d..3b3324351ef6 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -427,12 +427,6 @@ std::pair TargetTypeMangling(const Map& inputs return {target_input_mod, target_host}; } -// TVM_REGISTER_GLOBAL("driver.target_mangling") -// .set_body_typed([](const Map& inputs_arg, IRModule mod, Target target, -// Target target_host_arg) { -// return TargetTypeMangling(inputs_arg, mod, target, target_host_arg); -// }); - TVM_REGISTER_GLOBAL("driver.target_mangling") .set_body_typed([](const Map& inputs_arg, Target target, Target target_host_arg) { @@ -445,6 +439,45 @@ TVM_REGISTER_GLOBAL("driver.host_target_mangling") return TargetTypeMangling(inputs_arg, target, target_host_arg).second; }); +runtime::Module finalizeModule(const Map& inputs_arg, Target host_target) { + std::vector device_modules; + + IRModule mhost_all = IRModule(Map()); + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + for (const auto& it : inputs_arg) { + if (it.second.defined()) { + auto pair = SplitFuncsToDevHostMods(it.second, it.first, host_target); + auto& host_mod = pair.first; + auto& device_mod = pair.second; + + ICHECK(host_mod.defined()) << "The split host module must be defined"; + + ICHECK(mhost_all.defined()) << "The host module must be defined"; + + mhost_all->Update(host_mod); + + if (device_mod->functions.size() != 0) { + device_modules.push_back(codegen::Build(device_mod, it.first)); + } + } + } + runtime::Module complete_mod = codegen::Build(mhost_all, host_target); + // Import all modules + for (const auto& it : device_modules) { + if (it.operator->()) { + complete_mod.Import(it); + } + } + return complete_mod; +} + +TVM_REGISTER_GLOBAL("driver.finalize_module") + .set_body_typed([](const Map& inputs_arg, Target host_target) { + return finalizeModule(inputs_arg, host_target); + }); + // Can we make this take one annotated IRModule? // // Build for heterogeneous execution. 
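With driver.finalize_module registered above, the per-target split / codegen / import loop now sits behind a single FFI call. Below is a minimal sketch of reaching that entry point directly from Python, assuming the usual TVM_REGISTER_GLOBAL / get_global_func round trip; the toy schedule, the name "vecdouble", and the direct lookup are illustrative assumptions and are not taken from the patch.

import tvm
from tvm import te

# Lower a trivial schedule into a TIR IRModule.
n = 16
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
mod = tvm.lower(te.create_schedule(B.op), [A, B], name="vecdouble")

host = tvm.target.Target("llvm")
finalize_module = tvm._ffi.get_global_func("driver.finalize_module")

# The Map<Target, IRModule> argument is passed as a dict keyed by Target objects;
# the result should be a runtime.Module with any device modules already imported.
rt_mod = finalize_module({host: mod}, host)
print(rt_mod, rt_mod.imported_modules)

Keeping that loop behind one registered global means the Python driver crosses the FFI boundary once per build instead of once per target, which appears to be the point of this change.
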
From 6176155930d0780cd2473cbeb3b6d886f5d6b98c Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 14:04:01 +0300 Subject: [PATCH 11/35] [Build][WIP] Moving build utilities to C++ from Python --- include/tvm/target/codegen.h | 2 +- python/tvm/driver/build_module.py | 163 +++++++++++++----------------- python/tvm/relay/build_module.py | 22 ++-- python/tvm/target/codegen.py | 6 +- src/driver/driver_api.cc | 44 +++++--- src/relay/backend/vm/compiler.cc | 2 +- src/target/codegen.cc | 9 +- 7 files changed, 126 insertions(+), 122 deletions(-) diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index b2cab0e4bc45..0c82193c427a 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -45,7 +45,7 @@ using runtime::TVMRetValue; * \param target The target to be built. * \return The result runtime::Module. */ -runtime::Module Build(IRModule mod, Target target); +runtime::Module Codegen(IRModule mod, Target target); /*! * \brief Pack imported device library to a C file. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 2756e6fd7ac8..4a016ea25df8 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -136,7 +136,6 @@ def lower( "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) -# TODO(@electriclilies): This should be moved into C++. def _build_for_device(input_mod, target, target_host): """Build the lowered functions for a device with the given compilation target. @@ -154,19 +153,16 @@ def _build_for_device(input_mod, target, target_host): Returns ------- - fhost : IRModule + host_mod : IRModule The host IRModule. - mdev : tvm.module + device_mod : tvm.module A module that contains device code. """ - # Ideally delete check_and_update_host_consist from here - # target, target_host = Target.check_and_update_host_consist(target, target_host) - # Point 1 from tvm.driver import _ffi_api as _driver_ffi - mod_mixed = _driver_ffi.get_mod_mixed(input_mod) - device_mod = _driver_ffi.get_device_mod(mod_mixed) - host_mod = _driver_ffi.get_host_mod(mod_mixed) + mod_mixed = _driver_ffi.get_mod_mixed(input_mod, target) + device_mod = _driver_ffi.get_device_mod(mod_mixed, target) + host_mod = _driver_ffi.get_host_mod(mod_mixed, target_host) device_type = ndarray.device(target.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type and target_host == target: @@ -193,19 +189,14 @@ def build( ): """Build a function with arguments as signature. Code will be generated for devices coupled with target information. - Parameters ---------- - inputs : Union[tvm.te.schedule.Schedule, - tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] + inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built - args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function. - target : Optional[Union[str, Target]] The target and option of the compilation. - target_host : Optional[Union[str, Target]] Host compilation target, if target is device. When TVM compiles device specific program such as CUDA, @@ -214,27 +205,21 @@ def build( target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used. - name : Optional[str] The name of result function. - binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the binding of symbolic buffer to Tensor. 
By default, a new buffer is created for each tensor in the argument. - Returns ------- ret : tvm.module A module that combines both host and device code. - Examples ________ There are two typical example uses of this function depending on the type of the argument `inputs`: 1. it is an IRModule. - .. code-block:: python - n = 2 A = te.placeholder((n,), name='A') B = te.placeholder((n,), name='B') @@ -242,11 +227,8 @@ def build( s = tvm.te.create_schedule(C.op) m = tvm.lower(s, [A, B, C], name="test_add") rt_mod = tvm.build(m, target="llvm") - 2. it is a dict of compilation target to IRModule. - .. code-block:: python - n = 2 A = te.placeholder((n,), name='A') B = te.placeholder((n,), name='B') @@ -257,7 +239,6 @@ def build( m1 = tvm.lower(s1, [A, B, C], name="test_add1") m2 = tvm.lower(s2, [A, B, C], name="test_add2") rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm") - Note ---- See the note on :any:`tvm.target` on target string format. @@ -278,8 +259,6 @@ def build( f"but got {type(inputs)}." ) - # move to cpp from this point - # rest is codegen? # More target maps here... is inputs ever a map? # prepping and cutting module into chunks @@ -289,72 +268,62 @@ def build( # after talking to xiyou he said a lot of difficulty was trying to maintain # map correctly so I may just remove that. - # if not isinstance(inputs, (dict, container.Map)): - # target = Target.current() if target is None else target - # target = target if target else "llvm" - # target_input_mod = {target: input_mod} - # else: - # target_input_mod = inputs - - # # reshape tensor tests are failing without this one vm_reshape_tensor - - # for tar, mod in target_input_mod.items(): - # if not isinstance(tar, (str, Target)): - # raise ValueError( - # "The key of inputs must be str or " "Target when inputs is dict.") - # if not isinstance(mod, tvm.IRModule): - # raise ValueError( - # "inputs must be Schedule, IRModule," "or dict of str to IRModule.") - - # target_input_mod, target_host = Target.check_and_update_host_consist( - # target_input_mod, target_host - # ) - - # if not target_host: - # for tar, mod in target_input_mod.items(): - # tar = Target(tar) - # device_type = ndarray.device(tar.kind.name, 0).device_type - # if device_type == ndarray.cpu(0).device_type: - # target_host = tar - # break - # # Why is this here? - # # if not target_host: - # # target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - - # # why do we need to call chcek_and_update_host_consist again? - # target_input_mod, target_host = Target.check_and_update_host_consist( - # target_input_mod, target_host - # ) - from tvm.driver import _ffi_api as _driver_ffi - - target_input_mod = _driver_ffi.driver.target_mangling( - input_mod, target, target_host) - target_host = _driver_ffi.driver.host_target_mangling( - input_mod, target, target_host) - - # mod_host_all = tvm.IRModule({}) - - # # This is building for target not device though.. - # # From here through importing - # # the device modules could probably be consolidated into one C++ function. - # device_modules = [] - # for tar, input_mod in target_input_mod.items(): - # # mod_host is the module of the host.. bad name. - # mod_host, mdev = _build_for_device(input_mod, tar, target_host) - # # what are we updating here? - # mod_host_all.update(mod_host) - # device_modules.append(mdev) - - # # Generate a unified host module. - # rt_mod_host = codegen.build_module(mod_host_all, target_host) - - # # Import all modules. 
- # for mdev in device_modules: - # if mdev: - # rt_mod_host.import_module(mdev) - - rt_mod_host = _driver_ffi.driver.finalize_module( - target_input_mod, target_host) + if not isinstance(inputs, (dict, container.Map)): + target = Target.current() if target is None else target + target = target if target else "llvm" + target_input_mod = {target: input_mod} + else: + target_input_mod = inputs + + for tar, mod in target_input_mod.items(): + if not isinstance(tar, (str, Target)): + raise ValueError( + "The key of inputs must be str or " "Target when inputs is dict.") + if not isinstance(mod, tvm.IRModule): + raise ValueError( + "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + + if not target_host: + for tar, mod in target_input_mod.items(): + tar = Target(tar) + device_type = ndarray.device(tar.kind.name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + break + # Why is this here? + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + + # why do we need to call chcek_and_update_host_consist again? + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + + mod_host_all = tvm.IRModule({}) + + # This is building for target not device though.. + # From here through importing the device modules could probably be consolidated into one C++ function. + device_modules = [] + for tar, input_mod in target_input_mod.items(): + # mod_host is the module of the host.. bad name. + # Start with moving _build_for_device into c++ + mod_host, mdev = _build_for_device(input_mod, tar, target_host) + # what are we updating here? + mod_host_all.update(mod_host) + device_modules.append(mdev) + + print(mod_host_all) + # Generate a unified host module. + rt_mod_host = codegen.build_module(mod_host_all, target_host) + + # Import all modules. + for mdev in device_modules: + if mdev: + rt_mod_host.import_module(mdev) # stop moving to C++ here. if not isinstance(target_host, Target): @@ -378,11 +347,14 @@ def build( [rt_mod_host], target_host) else: to_return = rt_mod_host + print(target) + print("END OF BUILD") return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) - # What is OperatorModule and how is it different from runtime::Module + + class OperatorModule(Module): """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" @@ -391,11 +363,14 @@ def from_module(cls, mod, **kwargs): # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. # If an exception occurs in cls.__init__, handle will be deleted. For this reason, # set mod.handle to None. 
+ print("from module conv") handle = mod.handle + print(mod.handle) mod.handle = None return cls(handle, **kwargs) def __init__(self, handle, ir_module_by_target=None, name=None): super(OperatorModule, self).__init__(handle) self.ir_module_by_target = ir_module_by_target + print(name) self.name = name diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index fb9a8b78127f..94c5d7eca9a7 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -63,7 +63,8 @@ def build_target_by_device_type_map(target): tgts = {} if isinstance(target, (str, Target)): - dev_type = tvm_expr.IntImm("int32", _nd.device(str(target)).device_type) + dev_type = tvm_expr.IntImm( + "int32", _nd.device(str(target)).device_type) tgts[dev_type] = Target(target) elif isinstance(target, dict): for dev, tgt in target.items(): @@ -93,8 +94,6 @@ class BuildModule(object): """ def __init__(self): - # This is implicitly calling GetFunction in RelayBuildModuleNode which then calls the correct - # executor's version of that function. self.mod = _build_module._BuildModule() self._get_graph_json = self.mod["get_graph_json"] self._get_module = self.mod["get_module"] @@ -125,7 +124,7 @@ def build( to setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -253,7 +252,8 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.check_and_update_host_consist( + target, target_host) return build(mod, target, params=params, mod_name=mod_name).module @@ -284,6 +284,8 @@ def get_executor_from_target(target, target_host): return executor +# Which build is this one... Relay --> graph executor +# can params being parsed during run-time? def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): # fmt: off # pylint: disable=line-too-long @@ -305,7 +307,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm interpreter is used. params : dict of str to NDArray Input parameters to the graph that do not change @@ -338,7 +340,8 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: - raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None") + raise ValueError( + "target host must be the type of str, " + "tvm.target.Target, or None") target, target_host = Target.check_and_update_host_consist( target, target_host, target_is_dict_key=False @@ -454,7 +457,7 @@ def bind_params_by_name(func, params): class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. - This executor is used for debug and testing purpoes. + This executor is used for debug and testing purposes. 
Parameters ---------- @@ -495,7 +498,8 @@ def _unflatten(flat_iter, cur_type): field = _unflatten(flat_iter, field_type) fields.append(field) return fields - raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) + raise ValueError("Return type", ret_type, + "contains unsupported type", cur_type) def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 21d154a54279..65e8a2959a10 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -38,7 +38,8 @@ def build_module(mod, target): print("In codegen build module") # Where is Build defined? can't find it, can only find target.build.something... target = Target(target) if isinstance(target, str) else target - return _ffi_api.Build(mod, target) + + return _ffi_api.Codegen(mod, target) def llvm_lookup_intrinsic_id(name): @@ -75,4 +76,5 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError("LLVM version is not available, please check if you build with LLVM") + raise RuntimeError( + "LLVM version is not available, please check if you build with LLVM") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 3b3324351ef6..7bf3b7f5e671 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -42,6 +42,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); using runtime::PackedFunc; @@ -155,6 +156,13 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } +transform::Pass AnnotateEntryFunc(bool b) { + auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); +} + template transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { @@ -459,11 +467,11 @@ runtime::Module finalizeModule(const Map& inputs_arg, Target h mhost_all->Update(host_mod); if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); + device_modules.push_back(codegen::Codegen(device_mod, it.first)); } } } - runtime::Module complete_mod = codegen::Build(mhost_all, host_target); + runtime::Module complete_mod = codegen::Codegen(mhost_all, host_target); // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { @@ -527,12 +535,12 @@ runtime::Module build(const Map& inputs_arg, const Target& tar mhost_all->Update(host_mod); if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); + device_modules.push_back(codegen::Codegen(device_mod, it.first)); } } } - runtime::Module mhost = codegen::Build(mhost_all, target_host); + runtime::Module mhost = codegen::Codegen(mhost_all, target_host); // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { @@ -576,8 +584,13 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = 
transform::PassContext::Current(); Array mixed_pass_list; - if (target.defined()) { - mixed_pass_list.push_back(BindTarget(target)); + + mixed_pass_list.push_back(BindTarget(target)); + + bool is_entry_func = false; + if (mixed_mod->functions.size() == 1) { + is_entry_func = pass_ctx->GetConfig("tir.is_entry_func", Bool(true)).value(); + mixed_pass_list.push_back(AnnotateEntryFunc(is_entry_func)); } mixed_pass_list.push_back(tir::transform::VerifyMemory()); @@ -586,12 +599,23 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { // Python annotates all functions in the mod with the Target passed in here; I think we // shouldn't have to do that. + printf("\n"); + printf("***************************** \n"); + printf("%d\n", is_entry_func); + printf("***************************** \n"); + + // if (is_entry_func) { + // mixed_mod = WithAttr(std::move(mixed_mod), "tir.is_entry_func", Bool(true)); + // } + bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } + // mixed_mod->GetAttr + mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); @@ -627,9 +651,7 @@ IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host) { CallingConv::kDeviceKernelLaunch; })); - if (target_host.defined()) { - host_pass_list.push_back(BindTarget(target_host)); - } + host_pass_list.push_back(BindTarget(target_host)); host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); @@ -650,9 +672,7 @@ IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target) { CallingConv::kDeviceKernelLaunch; })); - if (target.defined()) { - device_pass_list.push_back(BindTarget(target)); - } + device_pass_list.push_back(BindTarget(target)); device_pass_list.push_back(tir::transform::LowerWarpMemory()); device_pass_list.push_back(tir::transform::Simplify()); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3eab91d202c..a7dad34a9fc7 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -830,7 +830,7 @@ class VMFunctionCompiler : ExprFunctor { /*! * \brief Compile a pattern match expression - * It first converts the pattern match expression into a desicision tree, the condition + * It first converts the pattern match expression into a decision tree, the condition * could be object comparison or variable binding. If any of the condition fails in a clause, * the decision tree switches to check the conditions of next clause and so on. If no clause * matches the value, a fatal node is inserted. diff --git a/src/target/codegen.cc b/src/target/codegen.cc index ff019691d252..3ef0dc7dc6fa 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -50,14 +50,17 @@ runtime::Module BuildFunc(PrimFunc func, Target target) { } -// Maybe don't need this here... +// Maybe don't need this here... runtime::Module BuildTargetMap(Map, Target Target) { } */ // Leave this -- why not called codegen? 
// TODO(@electriclilies): Rename this to Codegen (dont get consensus just try to put it in) -runtime::Module Build(IRModule mod, Target target) { +// i agree +runtime::Module Codegen(IRModule mod, Target target) { + printf("IN CODE Gen build !!!\n"); + if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -329,7 +332,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, } // Where build is registered -TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); +TVM_REGISTER_GLOBAL("target.Codegen").set_body_typed(Codegen); // Export two auxiliary function to the runtime namespace. TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); From 09aaf880cb562e1453beefad940724746a16cf9b Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 14:20:18 +0300 Subject: [PATCH 12/35] [Build] Remove comments --- apps/ios_rpc/README.md | 4 +- include/tvm/runtime/module.h | 2 - python/tvm/driver/build_module.py | 13 ----- src/ir/module.cc | 1 - src/target/codegen.cc | 21 -------- tests/cpp/relay_build_module_test.cc | 3 -- .../unittest/test_runtime_heterogeneous.py | 49 ++++++++++++------- 7 files changed, 34 insertions(+), 59 deletions(-) diff --git a/apps/ios_rpc/README.md b/apps/ios_rpc/README.md index c268d15d0179..2d9cc52dc0ad 100644 --- a/apps/ios_rpc/README.md +++ b/apps/ios_rpc/README.md @@ -79,7 +79,7 @@ You can get value of your `team_id` in the following ways: select target `tvmrpc`. At the bottom of this panel go to `Signing & Capabilities` tab and in the field `Team` select your local developer profile (`Your Name (Personal Team)`). - + On the first run of the application you may see message `Could not launch "tvmrpc"` in the XCode and message `Untrusted Developer` on your device. In this case it will be necessary to check the certificate. Open @@ -210,7 +210,7 @@ model and execute it on the target device. For this purpose we will use ```shell python3 tests/ios_rpc_test.py --host --port 9190 --mode "tracker" ``` -The output will be the same as in section +The output will be the same as in section [Standalone RPC](#standalone-rpc). ## Communication without Wi-Fi and speed up in case of slow Wi-Fi diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 8560f7b399c7..71be8d218d2d 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -134,7 +134,6 @@ class TVM_DLL ModuleNode : public Object { * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - // This is pure virtual which means its only instantiated by subclasses of ModuleNode. virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) = 0; /*! @@ -241,7 +240,6 @@ constexpr const char* tvm_entrypoint_suffix = "run"; inline void Module::Import(Module other) { return (*this)->Import(other); } -// seems questionable to provide a mutable pointer into the runtime module? inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4a016ea25df8..baadb0acdae9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -87,8 +87,6 @@ def schedule_to_module( """ return ffi.schedule_to_module(sch, args, name, binds) -# cpp duck typing ? 
- def lower( inp: Union[schedule.Schedule, PrimFunc, IRModule], @@ -259,15 +257,6 @@ def build( f"but got {type(inputs)}." ) - # rest is codegen? - # More target maps here... is inputs ever a map? - # prepping and cutting module into chunks - - # 1. get into c++ - # 2. Remove everywhere that takes map - # after talking to xiyou he said a lot of difficulty was trying to maintain - # map correctly so I may just remove that. - if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -294,11 +283,9 @@ def build( if device_type == ndarray.cpu(0).device_type: target_host = tar break - # Why is this here? if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - # why do we need to call chcek_and_update_host_consist again? target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host ) diff --git a/src/ir/module.cc b/src/ir/module.cc index 606991e59982..15c441d61a23 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -522,7 +522,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "IRModule(" << node->functions << ")"; - // << "attrs = " << node->attrs; }); } // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 3ef0dc7dc6fa..4ec17adc5bd4 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -40,27 +40,7 @@ namespace tvm { namespace codegen { -/* -runtime::Module BuildSchedule(Schedule sch, Target target) { - // call lower schedule to module - // call build?? -} - -runtime::Module BuildFunc(PrimFunc func, Target target) { - -} - -// Maybe don't need this here... -runtime::Module BuildTargetMap(Map, Target Target) { - -} -*/ -// Leave this -- why not called codegen? -// TODO(@electriclilies): Rename this to Codegen (dont get consensus just try to put it in) -// i agree runtime::Module Codegen(IRModule mod, Target target) { - printf("IN CODE Gen build !!!\n"); - if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -331,7 +311,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -// Where build is registered TVM_REGISTER_GLOBAL("target.Codegen").set_body_typed(Codegen); // Export two auxiliary function to the runtime namespace. diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 8aa603d1a818..ebb2867e7b69 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -66,9 +66,6 @@ TVM_REGISTER_GLOBAL("relay.backend.lower_call") OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); auto impl = strategy->specializations[0]->implementations[0]; auto outs = impl.Compute(call->attrs, inputs, out_type); - // Using make_LoweredOutput here - // wait ok is this the python LoweredOutput or the C++ LoweredOutput? - // This is a test. 
so its probably ok auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); if (!f) { LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index c1d1267d5ea1..494f6351fdcf 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -152,7 +152,8 @@ def check_device(device, target_device): ) target = topi.cpp.TEST_create_target(device) schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add]) - lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add") + lower_add = tvm.lower( + schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add") # Insert copy. Neither compute nor schedule is required for the copy # node. The compute will be performed at runtime which is just data @@ -166,7 +167,8 @@ def check_device(device, target_device): ) schedule_sub = te.create_schedule(elemwise_sub.op) lower_sub = tvm.lower( - schedule_sub, [tensor_copy, tensor_c, elemwise_sub], name="elemwise_sub" + schedule_sub, [tensor_copy, tensor_c, + elemwise_sub], name="elemwise_sub" ) target_flist = {target_device: lower_add, target_host: lower_sub} @@ -175,9 +177,12 @@ def check_device(device, target_device): dev = [host_dev, device_dev] mod = graph_executor.create(graph, mhost, dev) params = {} - params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype) - params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype) - params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype) + params["A"] = tensor_a = np.random.uniform( + size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform( + size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform( + size=shape).astype(tensor_c.dtype) mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) @@ -380,13 +385,17 @@ def check_device(device, target_device): shape, lambda *i: copy_sub_add(*i) + tensor_d(*i), name="elemwise_add1" ) target = topi.cpp.TEST_create_target(device) - add_schedule0 = topi.cpp.cuda.schedule_injective(target, [elemwise_add0]) + add_schedule0 = topi.cpp.cuda.schedule_injective( + target, [elemwise_add0]) lower_add0 = tvm.lower( - add_schedule0, [tensor_a, tensor_b, elemwise_add0], name="elemwise_add0" + add_schedule0, [tensor_a, tensor_b, + elemwise_add0], name="elemwise_add0" ) - add_schedule1 = topi.cpp.cuda.schedule_injective(target, [elemwise_add1]) + add_schedule1 = topi.cpp.cuda.schedule_injective( + target, [elemwise_add1]) lower_add1 = tvm.lower( - add_schedule1, [tensor_d, copy_sub_add, elemwise_add1], name="elemwise_add1" + add_schedule1, [tensor_d, copy_sub_add, + elemwise_add1], name="elemwise_add1" ) # Create module for sub whose target is the host. 
tensor_c = te.placeholder(shape, name="C") @@ -395,27 +404,32 @@ def check_device(device, target_device): ) sub_schedule = te.create_schedule(elemwise_sub.op) lower_sub = tvm.lower( - sub_schedule, [copy_add_sub, tensor_c, elemwise_sub], name="elemwise_sub" + sub_schedule, [copy_add_sub, tensor_c, + elemwise_sub], name="elemwise_sub" ) lower_add0.update(lower_add1) - # TODO: fix this target_flist = {target_device: lower_add0, target_host: lower_sub} target = tvm.target.Target(target, target_host) mhost = tvm.build(target_flist, target=target) dev = [host_dev, device_dev] params = {} - params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype) - params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype) - params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype) - params["D"] = tensor_d = np.random.uniform(size=shape).astype(tensor_d.dtype) + params["A"] = tensor_a = np.random.uniform( + size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform( + size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform( + size=shape).astype(tensor_c.dtype) + params["D"] = tensor_d = np.random.uniform( + size=shape).astype(tensor_d.dtype) def check_verify(): mod = graph_executor.create(graph, mhost, dev) mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) - np.testing.assert_equal(out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) + np.testing.assert_equal( + out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) def check_load_module(): temp = utils.tempdir() @@ -429,7 +443,8 @@ def check_load_module(): mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) - np.testing.assert_equal(out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) + np.testing.assert_equal( + out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) check_verify() check_load_module() From 311632be39b9208e3f9402cce7c1116fa82d84fd Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 14:28:49 +0300 Subject: [PATCH 13/35] [lint] Pass black --- python/tvm/driver/build_module.py | 22 ++++----- python/tvm/relay/backend/compile_engine.py | 1 + python/tvm/relay/build_module.py | 12 ++--- python/tvm/target/codegen.py | 3 +- python/tvm/target/target.py | 3 +- src/relay/backend/build_module.cc | 1 - .../unittest/test_runtime_heterogeneous.py | 48 +++++++------------ 7 files changed, 33 insertions(+), 57 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index baadb0acdae9..8d35f8b2641c 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -130,8 +130,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def _build_for_device(input_mod, target, target_host): @@ -158,6 +157,7 @@ def _build_for_device(input_mod, target, target_host): A module that contains device code. 
""" from tvm.driver import _ffi_api as _driver_ffi + mod_mixed = _driver_ffi.get_mod_mixed(input_mod, target) device_mod = _driver_ffi.get_device_mod(mod_mixed, target) host_mod = _driver_ffi.get_host_mod(mod_mixed, target_host) @@ -171,8 +171,9 @@ def _build_for_device(input_mod, target, target_host): ) # rt_mod_dev is runtime::Module so this can be moved out maybe? - rt_mod_dev = codegen.build_module(device_mod, target) if len( - device_mod.functions) != 0 else None + rt_mod_dev = ( + codegen.build_module(device_mod, target) if len(device_mod.functions) != 0 else None + ) # TIR module for the host, runtime module for devices? return host_mod, rt_mod_dev @@ -266,11 +267,9 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError( - "The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError( - "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -323,15 +322,13 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host print(target) @@ -339,6 +336,7 @@ def build( return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) + # What is OperatorModule and how is it different from runtime::Module diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a2b7306eab3c..071d9e6973d9 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -277,6 +277,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) # Returns LoweredOutput + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 94c5d7eca9a7..905b8e2d9f9d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -63,8 +63,7 @@ def build_target_by_device_type_map(target): tgts = {} if isinstance(target, (str, Target)): - dev_type = tvm_expr.IntImm( - "int32", _nd.device(str(target)).device_type) + dev_type = tvm_expr.IntImm("int32", _nd.device(str(target)).device_type) tgts[dev_type] = Target(target) elif isinstance(target, dict): for dev, tgt in target.items(): @@ -252,8 +251,7 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. 
""" - target, target_host = Target.check_and_update_host_consist( - target, target_host) + target, target_host = Target.check_and_update_host_consist(target, target_host) return build(mod, target, params=params, mod_name=mod_name).module @@ -340,8 +338,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: - raise ValueError( - "target host must be the type of str, " + "tvm.target.Target, or None") + raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None") target, target_host = Target.check_and_update_host_consist( target, target_host, target_is_dict_key=False @@ -498,8 +495,7 @@ def _unflatten(flat_iter, cur_type): field = _unflatten(flat_iter, field_type) fields.append(field) return fields - raise ValueError("Return type", ret_type, - "contains unsupported type", cur_type) + raise ValueError("Return type", ret_type, "contains unsupported type", cur_type) def _graph_wrapper(*args, **kwargs): args = self._convert_args(self.mod["main"], args, kwargs) diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 65e8a2959a10..078c7995a65d 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -76,5 +76,4 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError( - "LLVM version is not available, please check if you build with LLVM") + raise RuntimeError("LLVM version is not available, please check if you build with LLVM") diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index e7e5e0277a8e..dc85c425c22a 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -174,8 +174,7 @@ def list_kinds(): """Returns the list of available target names.""" return list(_ffi_api.ListTargetKinds()) - - # TODO: make this return IRModule? idk it seems + # TODO: make this return IRModule? idk it seems @staticmethod def check_and_update_host_consist(target, host=None, target_is_dict_key=True): """A helper function that merges a legacy "target, target_host" pair, then returns diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 2efbf2bdef4c..43393efec569 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -195,7 +195,6 @@ class RelayBuildModule : public runtime::ModuleNode { } }); } else if (name == "get_irmodule") { - // GetIRModule just calls the GetFunction of the executor, ends up in a DIFFERENT ModuleNode's GetFunction (oof) return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->executor_codegen_->GetIRModule(); }); diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index 494f6351fdcf..167f61d748c2 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -152,8 +152,7 @@ def check_device(device, target_device): ) target = topi.cpp.TEST_create_target(device) schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add]) - lower_add = tvm.lower( - schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add") + lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add], name="elemwise_add") # Insert copy. Neither compute nor schedule is required for the copy # node. 
The compute will be performed at runtime which is just data @@ -167,8 +166,7 @@ def check_device(device, target_device): ) schedule_sub = te.create_schedule(elemwise_sub.op) lower_sub = tvm.lower( - schedule_sub, [tensor_copy, tensor_c, - elemwise_sub], name="elemwise_sub" + schedule_sub, [tensor_copy, tensor_c, elemwise_sub], name="elemwise_sub" ) target_flist = {target_device: lower_add, target_host: lower_sub} @@ -177,12 +175,9 @@ def check_device(device, target_device): dev = [host_dev, device_dev] mod = graph_executor.create(graph, mhost, dev) params = {} - params["A"] = tensor_a = np.random.uniform( - size=shape).astype(tensor_a.dtype) - params["B"] = tensor_b = np.random.uniform( - size=shape).astype(tensor_b.dtype) - params["C"] = tensor_c = np.random.uniform( - size=shape).astype(tensor_c.dtype) + params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype) mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) @@ -385,17 +380,13 @@ def check_device(device, target_device): shape, lambda *i: copy_sub_add(*i) + tensor_d(*i), name="elemwise_add1" ) target = topi.cpp.TEST_create_target(device) - add_schedule0 = topi.cpp.cuda.schedule_injective( - target, [elemwise_add0]) + add_schedule0 = topi.cpp.cuda.schedule_injective(target, [elemwise_add0]) lower_add0 = tvm.lower( - add_schedule0, [tensor_a, tensor_b, - elemwise_add0], name="elemwise_add0" + add_schedule0, [tensor_a, tensor_b, elemwise_add0], name="elemwise_add0" ) - add_schedule1 = topi.cpp.cuda.schedule_injective( - target, [elemwise_add1]) + add_schedule1 = topi.cpp.cuda.schedule_injective(target, [elemwise_add1]) lower_add1 = tvm.lower( - add_schedule1, [tensor_d, copy_sub_add, - elemwise_add1], name="elemwise_add1" + add_schedule1, [tensor_d, copy_sub_add, elemwise_add1], name="elemwise_add1" ) # Create module for sub whose target is the host. 
tensor_c = te.placeholder(shape, name="C") @@ -404,8 +395,7 @@ def check_device(device, target_device): ) sub_schedule = te.create_schedule(elemwise_sub.op) lower_sub = tvm.lower( - sub_schedule, [copy_add_sub, tensor_c, - elemwise_sub], name="elemwise_sub" + sub_schedule, [copy_add_sub, tensor_c, elemwise_sub], name="elemwise_sub" ) lower_add0.update(lower_add1) @@ -414,22 +404,17 @@ def check_device(device, target_device): mhost = tvm.build(target_flist, target=target) dev = [host_dev, device_dev] params = {} - params["A"] = tensor_a = np.random.uniform( - size=shape).astype(tensor_a.dtype) - params["B"] = tensor_b = np.random.uniform( - size=shape).astype(tensor_b.dtype) - params["C"] = tensor_c = np.random.uniform( - size=shape).astype(tensor_c.dtype) - params["D"] = tensor_d = np.random.uniform( - size=shape).astype(tensor_d.dtype) + params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform(size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform(size=shape).astype(tensor_c.dtype) + params["D"] = tensor_d = np.random.uniform(size=shape).astype(tensor_d.dtype) def check_verify(): mod = graph_executor.create(graph, mhost, dev) mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) - np.testing.assert_equal( - out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) + np.testing.assert_equal(out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) def check_load_module(): temp = utils.tempdir() @@ -443,8 +428,7 @@ def check_load_module(): mod.set_input(**params) mod.run() out = mod.get_output(0, tvm.nd.empty(shape)) - np.testing.assert_equal( - out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) + np.testing.assert_equal(out.numpy(), tensor_a + tensor_b - tensor_c + tensor_d) check_verify() check_load_module() From 0c28839b7c810268fc10af06f3d8b6f9e74780e0 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 15:14:42 +0300 Subject: [PATCH 14/35] More formating --- include/tvm/driver/driver_api.h | 6 ++++-- python/tvm/driver/build_module.py | 16 +--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 1ac3e703695e..bf136a273887 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -46,7 +46,9 @@ namespace tvm { /*! * \brief Returns the optimized IRModule for original fused module (pre split) that contains device - * and host code. \param mixed_mod The original mixed module. \param target The device Target. + * and host code. + * \param mixed_mod The original mixed module. + * \param target The device Target. * \return The result optimized mixed module. */ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target); @@ -62,7 +64,7 @@ IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target); /*! * \brief Returns the optimized IRModule for the host Target after device/host from mixed module. * \param mixed_mod The optimized mixed module. - * \param target The host Target. + * \param target_host The host Target. * \return The result optimized host module. 
*/ IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host); diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 8d35f8b2641c..dabd9084d93f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -27,16 +27,15 @@ from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container -from tvm.ir import CallingConv from tvm.tir import PrimFunc from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm.target import Target from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from tvm.driver import _ffi_api as _driver_ffi from . import _ffi_api as ffi @@ -156,8 +155,6 @@ def _build_for_device(input_mod, target, target_host): device_mod : tvm.module A module that contains device code. """ - from tvm.driver import _ffi_api as _driver_ffi - mod_mixed = _driver_ffi.get_mod_mixed(input_mod, target) device_mod = _driver_ffi.get_device_mod(mod_mixed, target) host_mod = _driver_ffi.get_host_mod(mod_mixed, target_host) @@ -291,18 +288,12 @@ def build( mod_host_all = tvm.IRModule({}) - # This is building for target not device though.. - # From here through importing the device modules could probably be consolidated into one C++ function. device_modules = [] for tar, input_mod in target_input_mod.items(): - # mod_host is the module of the host.. bad name. - # Start with moving _build_for_device into c++ mod_host, mdev = _build_for_device(input_mod, tar, target_host) - # what are we updating here? mod_host_all.update(mod_host) device_modules.append(mdev) - print(mod_host_all) # Generate a unified host module. rt_mod_host = codegen.build_module(mod_host_all, target_host) @@ -331,8 +322,6 @@ def build( to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host - print(target) - print("END OF BUILD") return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) @@ -348,14 +337,11 @@ def from_module(cls, mod, **kwargs): # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. # If an exception occurs in cls.__init__, handle will be deleted. For this reason, # set mod.handle to None. - print("from module conv") handle = mod.handle - print(mod.handle) mod.handle = None return cls(handle, **kwargs) def __init__(self, handle, ir_module_by_target=None, name=None): super(OperatorModule, self).__init__(handle) self.ir_module_by_target = ir_module_by_target - print(name) self.name = name From 5008b75008e4a73950a09a188dd146ebefbb9700 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 16:48:52 +0300 Subject: [PATCH 15/35] Move more build functionality into cpp --- python/tvm/driver/build_module.py | 22 ++-------------------- src/driver/driver_api.cc | 23 +++++++++-------------- 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index dabd9084d93f..3a1a0a8ff8d1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -255,6 +255,7 @@ def build( f"but got {type(inputs)}." 
) + # starts here if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -282,27 +283,8 @@ def build( if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) - - mod_host_all = tvm.IRModule({}) - - device_modules = [] - for tar, input_mod in target_input_mod.items(): - mod_host, mdev = _build_for_device(input_mod, tar, target_host) - mod_host_all.update(mod_host) - device_modules.append(mdev) - - # Generate a unified host module. - rt_mod_host = codegen.build_module(mod_host_all, target_host) - - # Import all modules. - for mdev in device_modules: - if mdev: - rt_mod_host.import_module(mdev) + rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host) - # stop moving to C++ here. if not isinstance(target_host, Target): target_host = Target(target_host) if ( diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7bf3b7f5e671..0ca6c671f686 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -447,16 +447,20 @@ TVM_REGISTER_GLOBAL("driver.host_target_mangling") return TargetTypeMangling(inputs_arg, target, target_host_arg).second; }); -runtime::Module finalizeModule(const Map& inputs_arg, Target host_target) { +runtime::Module FinalizeModule(const Map& inputs_arg, const Target& host_target) { std::vector device_modules; + Map inputs = inputs_arg; + Target target_host = host_target; + + CheckAndUpdateHostConsistency(&inputs, &target_host); IRModule mhost_all = IRModule(Map()); ICHECK(mhost_all.defined()) << "The host module must be defined"; - for (const auto& it : inputs_arg) { + for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitFuncsToDevHostMods(it.second, it.first, host_target); + auto pair = SplitFuncsToDevHostMods(it.second, it.first, target_host); auto& host_mod = pair.first; auto& device_mod = pair.second; @@ -471,7 +475,7 @@ runtime::Module finalizeModule(const Map& inputs_arg, Target h } } } - runtime::Module complete_mod = codegen::Codegen(mhost_all, host_target); + runtime::Module complete_mod = codegen::Codegen(mhost_all, target_host); // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { @@ -483,7 +487,7 @@ runtime::Module finalizeModule(const Map& inputs_arg, Target h TVM_REGISTER_GLOBAL("driver.finalize_module") .set_body_typed([](const Map& inputs_arg, Target host_target) { - return finalizeModule(inputs_arg, host_target); + return FinalizeModule(inputs_arg, host_target); }); // Can we make this take one annotated IRModule? @@ -599,15 +603,6 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { // Python annotates all functions in the mod with the Target passed in here; I think we // shouldn't have to do that. 
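Note: moving this per-target loop behind driver.finalize_module should not change the
user-facing path. A minimal sanity sketch of the flow that now ends in the C++
FinalizeModule, assuming an LLVM-enabled build (the workload and the "add_one" name
are only an example, not something added by this patch):

import numpy as np
import tvm
from tvm import te

n = 128
A = te.placeholder((n,), name="A", dtype="float32")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

mod = tvm.lower(s, [A, B], name="add_one")  # IRModule holding a single PrimFunc
rt_mod = tvm.build(mod, target="llvm")      # single-target path through finalize_module

dev = tvm.cpu(0)
a = tvm.nd.array(np.arange(n, dtype="float32"), dev)
b = tvm.nd.array(np.zeros(n, dtype="float32"), dev)
rt_mod["add_one"](a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() + 1.0)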
- printf("\n"); - printf("***************************** \n"); - printf("%d\n", is_entry_func); - printf("***************************** \n"); - - // if (is_entry_func) { - // mixed_mod = WithAttr(std::move(mixed_mod), "tir.is_entry_func", Bool(true)); - // } - bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { From f73791ad8826e32692948bb1e6b6b6ad649a5aa4 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 18:10:27 +0300 Subject: [PATCH 16/35] Remove comments --- src/driver/driver_api.cc | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0ca6c671f686..c001be6d80d5 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -454,6 +454,10 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta CheckAndUpdateHostConsistency(&inputs, &target_host); + if (!target_host.defined()) { + target_host = DefaultTargetHost(target_host); + } + IRModule mhost_all = IRModule(Map()); ICHECK(mhost_all.defined()) << "The host module must be defined"; @@ -490,11 +494,6 @@ TVM_REGISTER_GLOBAL("driver.finalize_module") return FinalizeModule(inputs_arg, host_target); }); -// Can we make this take one annotated IRModule? -// -// Build for heterogeneous execution. -// -// It looks like this version of build doesn't lower, unlike the python version.... runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); @@ -582,8 +581,6 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return build(inputs, target_host); } -// Gets the "mixed_module" from python driver/build_module.py's build function. -// Honestly not really sure what this actually is. IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -600,17 +597,12 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - // Python annotates all functions in the mod with the Target passed in here; I think we - // shouldn't have to do that. 
- bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } - // mixed_mod->GetAttr - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); From ba98e6fcbebf41dd1b61b7eef73910554032bc0b Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 24 Sep 2021 18:38:25 +0300 Subject: [PATCH 17/35] Remove unused defs and imports --- python/tvm/driver/build_module.py | 50 +------------------------------ 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 3a1a0a8ff8d1..483244ad6b32 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -20,7 +20,6 @@ """ from typing import Union, Optional, List, Mapping -import warnings import tvm.tir @@ -29,7 +28,6 @@ from tvm.ir import container from tvm.tir import PrimFunc from tvm.ir.module import IRModule -from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm.target import Target @@ -132,49 +130,6 @@ def lower( raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) -def _build_for_device(input_mod, target, target_host): - """Build the lowered functions for a device with the given compilation - target. - - Parameters - ---------- - input_mod : IRModule - The schedule to be built. - - target : str or :any:`tvm.target.Target` - The target and option of the compilation. - - target_host : str or :any:`tvm.target.Target` - The host compilation target. - - Returns - ------- - host_mod : IRModule - The host IRModule. - - device_mod : tvm.module - A module that contains device code. - """ - mod_mixed = _driver_ffi.get_mod_mixed(input_mod, target) - device_mod = _driver_ffi.get_device_mod(mod_mixed, target) - host_mod = _driver_ffi.get_host_mod(mod_mixed, target_host) - - device_type = ndarray.device(target.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type and target_host == target: - assert len(device_mod.functions) == 0 - if "gpu" in target.keys and len(device_mod.functions) == 0: - warnings.warn( - "Specified target %s, but cannot find device code, did you do " "bind?" % target - ) - - # rt_mod_dev is runtime::Module so this can be moved out maybe? - rt_mod_dev = ( - codegen.build_module(device_mod, target) if len(device_mod.functions) != 0 else None - ) - # TIR module for the host, runtime module for devices? - return host_mod, rt_mod_dev - - def build( inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -200,7 +155,7 @@ def build( setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + otherwise a stackvm i nterpreter is used. name : Optional[str] The name of result function. 
binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] @@ -308,9 +263,6 @@ def build( return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) -# What is OperatorModule and how is it different from runtime::Module - - class OperatorModule(Module): """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" From f515c6fca039d4d91cfc7b94a351de4654791476 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sat, 25 Sep 2021 11:44:56 +0300 Subject: [PATCH 18/35] Address PR comments --- apps/ios_rpc/README.md | 4 ++-- python/tvm/driver/build_module.py | 22 +++++++++++++--------- python/tvm/relay/backend/compile_engine.py | 18 ++++++++++-------- python/tvm/target/codegen.py | 5 ++--- src/driver/driver_api.cc | 1 - src/relay/backend/aot_executor_codegen.cc | 1 - src/relay/backend/build_module.cc | 3 --- src/relay/backend/te_compiler_cache.cc | 3 --- src/relay/backend/utils.h | 1 - src/runtime/module.cc | 2 -- 10 files changed, 27 insertions(+), 33 deletions(-) diff --git a/apps/ios_rpc/README.md b/apps/ios_rpc/README.md index 2d9cc52dc0ad..c268d15d0179 100644 --- a/apps/ios_rpc/README.md +++ b/apps/ios_rpc/README.md @@ -79,7 +79,7 @@ You can get value of your `team_id` in the following ways: select target `tvmrpc`. At the bottom of this panel go to `Signing & Capabilities` tab and in the field `Team` select your local developer profile (`Your Name (Personal Team)`). - + On the first run of the application you may see message `Could not launch "tvmrpc"` in the XCode and message `Untrusted Developer` on your device. In this case it will be necessary to check the certificate. Open @@ -210,7 +210,7 @@ model and execute it on the target device. For this purpose we will use ```shell python3 tests/ios_rpc_test.py --host --port 9190 --mode "tracker" ``` -The output will be the same as in section +The output will be the same as in section [Standalone RPC](#standalone-rpc). ## Communication without Wi-Fi and speed up in case of slow Wi-Fi diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 483244ad6b32..51bf36a292e5 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -120,14 +120,15 @@ def lower( m : IRModule The result IRModule """ - # ffi.relay.lower_te_pass() + # TODO(@mikepapadim) introduce ffi.relay.lower_te_pass() if isinstance(inp, IRModule): return ffi.lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -155,7 +156,7 @@ def build( setup the dimensions and parameters correctly. target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, - otherwise a stackvm i nterpreter is used. + otherwise a stackvm interpreter is used. name : Optional[str] The name of result function. binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] @@ -195,9 +196,8 @@ def build( See the note on :any:`tvm.target` on target string format. """ - # Lowering if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): - # should this be te_lower instead? 
+ # TODO(@mikepapadim) replace with te_lower input_mod = lower(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) @@ -220,9 +220,11 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError( + "The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError( + "inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -250,13 +252,15 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module( + [rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module( + [rt_mod_host], target_host) else: to_return = rt_mod_host diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 071d9e6973d9..97357ffd5384 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -42,7 +42,8 @@ class LoweredOutput(Object): """Lowered output""" def __init__(self, outputs, implement): - self.__init_handle_by_constructor__(_backend._make_LoweredOutput, outputs, implement) + self.__init_handle_by_constructor__( + _backend._make_LoweredOutput, outputs, implement) @tvm._ffi.register_object("relay.CCacheKey") @@ -59,7 +60,8 @@ class CCacheKey(Object): """ def __init__(self, source_func, target): - self.__init_handle_by_constructor__(_backend._make_CCacheKey, source_func, target) + self.__init_handle_by_constructor__( + _backend._make_CCacheKey, source_func, target) @tvm._ffi.register_object("relay.CCacheValue") @@ -228,7 +230,8 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) if cfg.is_fallback: # Skip fallback config continue - logger.info("Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost) + logger.info("Implementation %s for %s has cost %.2e", + impl.name, op.name, cfg.cost) if best_cfg is None or best_cfg.cost > cfg.cost: best_autotvm_impl = impl best_cfg = cfg @@ -275,9 +278,6 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] -# Returns LoweredOutput - - @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" @@ -294,7 +294,8 @@ def lower_call(call, inputs, target): new_fields = [] for field in ret_type.fields: if isinstance(field, _ty.TensorType): - new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) + new_fields.append(_ty.TensorType( + get_shape(field.shape), field.dtype)) else: new_fields.append(field) ret_type = _ty.TupleType(new_fields) @@ -312,7 +313,8 @@ def lower_call(call, inputs, target): reenable_tracing = True if not is_dyn: - best_impl, outputs = select_implementation(op, 
call.attrs, inputs, ret_type, target) + best_impl, outputs = select_implementation( + op, call.attrs, inputs, ret_type, target) else: # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes. best_impl, outputs = select_implementation( diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 078c7995a65d..1558a13af85f 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -35,8 +35,6 @@ def build_module(mod, target): module : runtime.Module The corressponding module. """ - print("In codegen build module") - # Where is Build defined? can't find it, can only find target.build.something... target = Target(target) if isinstance(target, str) else target return _ffi_api.Codegen(mod, target) @@ -76,4 +74,5 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError("LLVM version is not available, please check if you build with LLVM") + raise RuntimeError( + "LLVM version is not available, please check if you build with LLVM") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c001be6d80d5..9da72b25bca9 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -381,7 +381,6 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); -// Splits module into one to run on the device and one to run the host. E.g., CUDA, OpenCL etc std::pair SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 1eb3d36516a2..d1191064cfe1 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -740,7 +740,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { *rv = get_param_id(key); }); } else if (name == "get_irmodule") { - // OK here is one get_irmodule! return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); } else if (name == "get_external_modules") { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 43393efec569..69dced36295e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -176,7 +176,6 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { - // OH this must be where the self.build = mod["build"] comes from! return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 5); this->Build(args[0], args[1], args[2], args[3], args[4]); @@ -489,11 +488,9 @@ class RelayBuildModule : public runtime::ModuleNode { executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); - // Another Map auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. 
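For reference, the Map<Target, IRModule> coming out of GetIRModule here has the same
shape as the {target: IRModule} dict that the Python tvm.build entry accepts for
heterogeneous builds. A rough sketch mirroring the heterogeneous runtime test, assuming
a CUDA-enabled build (the two kernels below are only examples):

import tvm
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
add = te.compute((n,), lambda i: A[i] + B[i], name="add")
s_add = te.create_schedule(add.op)
# Bind the loop so the add kernel is valid device code for the CUDA target.
s_add[add].bind(add.op.axis[0], te.thread_axis("threadIdx.x"))
lower_add = tvm.lower(s_add, [A, B, add], name="elemwise_add")

C = te.placeholder((n,), name="C")
sub = te.compute((n,), lambda i: A[i] - C[i], name="sub")
lower_sub = tvm.lower(te.create_schedule(sub.op), [A, C, sub], name="elemwise_sub")

# Same per-target map shape as lowered_funcs above; the host module that comes back
# has the CUDA device module imported into it.
target = tvm.target.Target("cuda", "llvm")
mhost = tvm.build({"cuda": lower_add, "llvm": lower_sub}, target=target)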
- // TODO(Ext_dev shouldn't be passed in in this module I think so eventually so we can nuke it) Target ext_dev("ext_dev"); if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { lowered_funcs.Set(ext_dev, IRModule()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 53803b7cf9e2..06b5dae89f5a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -54,7 +54,6 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode); TVM_REGISTER_NODE_TYPE(CCacheKeyNode); TVM_REGISTER_NODE_TYPE(CCacheValueNode); -// LoweredOutput constructor! LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { auto n = make_object(); n->outputs = std::move(outputs); @@ -225,7 +224,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array VisitExpr_(const CallNode* call_node) final { static auto fpattern = Op::GetAttrMap("TOpPattern"); - // So this is the PYTHON version not the C++ version defined in relay_build_module_test.cc static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); ICHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -256,7 +254,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator const auto* copy_input = inputs[0].operator->(); outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { - // Right so we need to change lower_call LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; impl = lowered_out->implementation; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 154dc359b4bf..f8ff20ece561 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -138,7 +138,6 @@ class FunctionInfo : public ObjectRef { */ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); -// LoweredOutput definition -- why is this redefined in python??? /*! * \brief Executor generator artifacts. Those artifacts are subsequently * used by the relay build process. diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 31b65ade6f43..f9c281ab9d02 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -63,12 +63,10 @@ void ModuleNode::Import(Module other) { PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) { ModuleNode* self = this; - // This must be the REAL GetFunction but IDK where it actually lives!! PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); if (pf != nullptr) return pf; if (query_imports) { for (Module& m : self->imports_) { - // Where is this defined? 
pf = m.operator->()->GetFunction(name, query_imports); if (pf != nullptr) { return pf; From 6b366c3cd8cdda89d5d9aa9daed662fb2d17279e Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sat, 25 Sep 2021 11:51:20 +0300 Subject: [PATCH 19/35] More PR comments --- src/relay/backend/aot_executor_codegen.cc | 8 ++++++-- src/relay/backend/compile_engine.cc | 1 - src/relay/backend/te_compiler_cache.cc | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d1191064cfe1..deca3b5a4c5a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -559,7 +559,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { target_host_(target_host), use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} - // Returns LoweredOutput LoweredOutput Codegen(relay::Function func, String mod_name) { auto aot_allocator = AOTOnDemandAllocator(); aot_allocator.Run(func); @@ -626,8 +625,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { + // The buffer_var is created with storage_scope to be global.workspace to be serviced by + // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor + // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and + // should not be lowered to the stack. For more details please refer to the discussion here: + // https://github.com/apache/tvm/issues/9022 te::Var buffer_var(MakeString("sid_", sid), - PointerType(PrimType(DataType::Int(8)), "global")); + PointerType(PrimType(DataType::Int(8)), "global.workspace")); sids_table_[sid] = buffer_var; } } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7e844bc11c28..0e7af2278375 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -282,7 +282,6 @@ CompileEngine& CompileEngine::Global() { TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); -// Make LoweredOutput TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") .set_body_typed([](tvm::Array outputs, OpImplementation impl) { return LoweredOutput(outputs, impl); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 06b5dae89f5a..d0e83765928a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -440,7 +440,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> With fresh_pass_ctx_scope(PassContext::Create()); std::unordered_map binds; - // Huh why are we lowering the schedule here?? Seems weird. 
IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, From 1e24b25a804a7a49157f9cac0aa74a6c0feada33 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sat, 25 Sep 2021 11:59:13 +0300 Subject: [PATCH 20/35] More comments --- python/tvm/driver/build_module.py | 15 +++++---------- python/tvm/relay/backend/compile_engine.py | 15 +++++---------- python/tvm/target/codegen.py | 3 +-- python/tvm/target/target.py | 1 - src/driver/driver_api.cc | 5 ----- 5 files changed, 11 insertions(+), 28 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 51bf36a292e5..d5442c0d1efe 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -127,8 +127,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -220,11 +219,9 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError( - "The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError( - "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -252,15 +249,13 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 97357ffd5384..e9129db7b200 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -42,8 +42,7 @@ class LoweredOutput(Object): """Lowered output""" def __init__(self, outputs, implement): - self.__init_handle_by_constructor__( - _backend._make_LoweredOutput, outputs, implement) + self.__init_handle_by_constructor__(_backend._make_LoweredOutput, outputs, implement) @tvm._ffi.register_object("relay.CCacheKey") @@ -60,8 +59,7 @@ class CCacheKey(Object): """ def __init__(self, source_func, target): - self.__init_handle_by_constructor__( - _backend._make_CCacheKey, source_func, target) + self.__init_handle_by_constructor__(_backend._make_CCacheKey, source_func, target) @tvm._ffi.register_object("relay.CCacheValue") @@ -230,8 +228,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) if cfg.is_fallback: # Skip fallback config continue - logger.info("Implementation %s for %s has cost %.2e", - impl.name, op.name, 
cfg.cost) + logger.info("Implementation %s for %s has cost %.2e", impl.name, op.name, cfg.cost) if best_cfg is None or best_cfg.cost > cfg.cost: best_autotvm_impl = impl best_cfg = cfg @@ -294,8 +291,7 @@ def lower_call(call, inputs, target): new_fields = [] for field in ret_type.fields: if isinstance(field, _ty.TensorType): - new_fields.append(_ty.TensorType( - get_shape(field.shape), field.dtype)) + new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) else: new_fields.append(field) ret_type = _ty.TupleType(new_fields) @@ -313,8 +309,7 @@ def lower_call(call, inputs, target): reenable_tracing = True if not is_dyn: - best_impl, outputs = select_implementation( - op, call.attrs, inputs, ret_type, target) + best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target) else: # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes. best_impl, outputs = select_implementation( diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 1558a13af85f..ec047e4f7a28 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -74,5 +74,4 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError( - "LLVM version is not available, please check if you build with LLVM") + raise RuntimeError("LLVM version is not available, please check if you build with LLVM") diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index dc85c425c22a..af2f5d857293 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -174,7 +174,6 @@ def list_kinds(): """Returns the list of available target names.""" return list(_ffi_api.ListTargetKinds()) - # TODO: make this return IRModule? idk it seems @staticmethod def check_and_update_host_consist(target, host=None, target_is_dict_key=True): """A helper function that merges a legacy "target, target_host" pair, then returns diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9da72b25bca9..fdb1764ba89a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -479,7 +479,6 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta } } runtime::Module complete_mod = codegen::Codegen(mhost_all, target_host); - // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { complete_mod.Import(it); @@ -501,7 +500,6 @@ runtime::Module build(const Map& inputs_arg, const Target& tar Target target_host = target_host_arg; // Fetch previous defined target host in targets - // this is redefined in python ahh CheckAndUpdateHostConsistency(&inputs, &target_host); if (!target_host.defined()) { @@ -543,7 +541,6 @@ runtime::Module build(const Map& inputs_arg, const Target& tar } runtime::Module mhost = codegen::Codegen(mhost_all, target_host); - // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { mhost.Import(it); @@ -569,8 +566,6 @@ runtime::Module build(const Map& inputs_arg, const Target& tar } // Build for homogeneous execution. -// Where is this called from?] 
-// called from compile engine and it accepts lowered functions runtime::Module build(const IRModule& funcs, const Target& target_arg, const Target& target_host_arg) { auto target = target_arg, target_host = target_host_arg; From af8c8e303c2ff69476ee5549055d18a987a5ad87 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sat, 25 Sep 2021 12:02:03 +0300 Subject: [PATCH 21/35] More comments --- python/tvm/ir/function.py | 1 - python/tvm/relay/build_module.py | 2 -- python/tvm/target/codegen.py | 1 - src/relay/backend/te_compiler.h | 1 - 4 files changed, 5 deletions(-) diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c879935b5011..c3f1bf5f562a 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -22,7 +22,6 @@ from . import _ffi_api -# Python CallingConv class CallingConv(IntEnum): """Possible kinds of calling conventions.""" diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 905b8e2d9f9d..f1686d2a03bb 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -282,8 +282,6 @@ def get_executor_from_target(target, target_host): return executor -# Which build is this one... Relay --> graph executor -# can params being parsed during run-time? def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): # fmt: off # pylint: disable=line-too-long diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index ec047e4f7a28..862366d0c082 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -36,7 +36,6 @@ def build_module(mod, target): The corressponding module. """ target = Target(target) if isinstance(target, str) else target - return _ffi_api.Codegen(mod, target) diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index b72e2cb6dca7..d5135e6301c4 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -43,7 +43,6 @@ #include #include #include -#include #include #include From 57b8039784f990f2764e3314821af4d4f42f2bdc Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Sun, 26 Sep 2021 11:09:03 +0300 Subject: [PATCH 22/35] Add comments on the new split function --- src/driver/driver_api.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index fdb1764ba89a..880ef94ea43b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -158,7 +158,7 @@ transform::Pass BindTarget(Target target) { transform::Pass AnnotateEntryFunc(bool b) { auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(b)); }; return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); } @@ -381,6 +381,11 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode); }); +/** + * This function takes the input module that contains both the device and host opts. + * Then, it applies transformation on the original module before splitting into separate modules for + * device and host. Then it also applies transformations on the new splitted modules. 
+ */ std::pair SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg, const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; From 0c4bf6ddb5c97ebb332e69be97bb3a3cdee50743 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Mon, 4 Oct 2021 15:44:18 +0300 Subject: [PATCH 23/35] Fix PR comments on clarity --- include/tvm/target/codegen.h | 2 +- src/driver/driver_api.cc | 18 +++++++++--------- src/target/codegen.cc | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 0c82193c427a..b2cab0e4bc45 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -45,7 +45,7 @@ using runtime::TVMRetValue; * \param target The target to be built. * \return The result runtime::Module. */ -runtime::Module Codegen(IRModule mod, Target target); +runtime::Module Build(IRModule mod, Target target); /*! * \brief Pack imported device library to a C file. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 880ef94ea43b..774e662d8405 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -156,7 +156,7 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } -transform::Pass AnnotateEntryFunc(bool b) { +static transform::Pass AnnotateEntryFunc(bool b) { auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(b)); }; @@ -386,8 +386,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") * Then, it applies transformation on the original module before splitting into separate modules for * device and host. Then it also applies transformations on the new splitted modules. 
*/ -std::pair SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { +std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, + const Target& target_host_arg) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); @@ -468,7 +468,7 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitFuncsToDevHostMods(it.second, it.first, target_host); + auto pair = SplitMixedModule(it.second, it.first, target_host); auto& host_mod = pair.first; auto& device_mod = pair.second; @@ -479,11 +479,11 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta mhost_all->Update(host_mod); if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Codegen(device_mod, it.first)); + device_modules.push_back(codegen::Build(device_mod, it.first)); } } } - runtime::Module complete_mod = codegen::Codegen(mhost_all, target_host); + runtime::Module complete_mod = codegen::Build(mhost_all, target_host); for (const auto& it : device_modules) { if (it.operator->()) { complete_mod.Import(it); @@ -529,7 +529,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitFuncsToDevHostMods(it.second, it.first, target_host); + auto pair = SplitMixedModule(it.second, it.first, target_host); auto& host_mod = pair.first; auto& device_mod = pair.second; @@ -540,12 +540,12 @@ runtime::Module build(const Map& inputs_arg, const Target& tar mhost_all->Update(host_mod); if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Codegen(device_mod, it.first)); + device_modules.push_back(codegen::Build(device_mod, it.first)); } } } - runtime::Module mhost = codegen::Codegen(mhost_all, target_host); + runtime::Module mhost = codegen::Build(mhost_all, target_host); for (const auto& it : device_modules) { if (it.operator->()) { mhost.Import(it); diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 4ec17adc5bd4..5a4aa39f01b4 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -40,7 +40,7 @@ namespace tvm { namespace codegen { -runtime::Module Codegen(IRModule mod, Target target) { +runtime::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -311,7 +311,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -TVM_REGISTER_GLOBAL("target.Codegen").set_body_typed(Codegen); +TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export two auxiliary function to the runtime namespace. 
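Quick check that the rename took: the global registered just above is the FFI entry
for codegen::Build, so it should be visible from Python as a packed function (sketch):

import tvm

build_f = tvm.get_global_func("target.Build", allow_missing=True)
assert build_f is not None, "target.Build should be registered after this patch"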
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); From 41cf6f3c8bc8eb04b32d4b520849f36407cd6b74 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Thu, 7 Oct 2021 14:46:31 +0300 Subject: [PATCH 24/35] Test CI --- python/tvm/driver/build_module.py | 80 +++++++++++++++++++++++++++---- src/driver/driver_api.cc | 31 +++++++++--- 2 files changed, 95 insertions(+), 16 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index d5442c0d1efe..d7fb1c143cb5 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -127,7 +127,8 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -142,7 +143,8 @@ def build( for devices coupled with target information. Parameters ---------- - inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] + inputs : Union[tvm.te.schedule.Schedule, + tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function. @@ -195,20 +197,66 @@ def build( See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): - # TODO(@mikepapadim) replace with te_lower + # if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): + # # TODO(@mikepapadim) replace with te_lower + # input_mod = lower(inputs, args, name=name, binds=binds) + # elif isinstance(inputs, (list, tuple, container.Array)): + # merged_mod = tvm.IRModule({}) + # for x in inputs: + # merged_mod.update(lower(x)) + # input_mod = merged_mod + # elif not isinstance(inputs, (dict, container.Map)): + # raise ValueError( + # f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " + # f"but got {type(inputs)}." 
+ # ) + + # # starts here + # if not isinstance(inputs, (dict, container.Map)): + # target = Target.current() if target is None else target + # target = target if target else "llvm" + # target_input_mod = {target: input_mod} + # else: + # target_input_mod = inputs + + # for tar, mod in target_input_mod.items(): + # if not isinstance(tar, (str, Target)): + # raise ValueError( + # "The key of inputs must be str or " "Target when inputs is dict.") + # if not isinstance(mod, tvm.IRModule): + # raise ValueError( + # "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + + # target_input_mod, target_host = Target.check_and_update_host_consist( + # target_input_mod, target_host + # ) + + # if not target_host: + # for tar, mod in target_input_mod.items(): + # tar = Target(tar) + # device_type = ndarray.device(tar.kind.name, 0).device_type + # if device_type == ndarray.cpu(0).device_type: + # target_host = tar + # break + # if not target_host: + # target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + + if isinstance(inputs, schedule.Schedule): + if args is None: + raise ValueError("args must be given for build from schedule") input_mod = lower(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) for x in inputs: merged_mod.update(lower(x)) input_mod = merged_mod + elif isinstance(inputs, (tvm.IRModule, PrimFunc)): + input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( - f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " + f"Inputs must be Schedule, IRModule or dict of target to IRModule, " f"but got {type(inputs)}." ) - # starts here if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target @@ -219,9 +267,11 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError( + "The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError( + "inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -237,8 +287,16 @@ def build( if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host) + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + if not isinstance(target_host, Target): target_host = Target(target_host) if ( @@ -249,13 +307,15 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module( + [rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module( + [rt_mod_host], target_host) else: to_return = rt_mod_host diff --git 
a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 774e662d8405..4889c19be679 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -393,6 +393,9 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& ICHECK(mod_mixed.defined()) << "This module must be defined"; + VLOG_CONTEXT << target->str(); + VLOG(0) << "Executing module pass with opt level: "; + mod_mixed = OptimizeMixedModule(mod_mixed, target); auto host_mod = OptimizeHostModule(mod_mixed, target_host); @@ -401,10 +404,13 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& // some final misc checks. auto keys = target->GetKeys(); + + // CheckAndUpdateHostConsistency(&target, &target_host); + bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && device_mod->functions.size() == 0) { - LOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + DLOG(WARNING) << "Specified target " << target->str() + << " but cannot find device code. Did you forget to bind?"; } return {host_mod, device_mod}; @@ -458,10 +464,22 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta CheckAndUpdateHostConsistency(&inputs, &target_host); + if (!target_host.defined()) { + for (const auto& it : inputs) { + if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { + target_host = it.first; + break; + } + } + } + if (!target_host.defined()) { target_host = DefaultTargetHost(target_host); } + // Update target host for all targets + CheckAndUpdateHostConsistency(&inputs, &target_host); + IRModule mhost_all = IRModule(Map()); ICHECK(mhost_all.defined()) << "The host module must be defined"; @@ -483,6 +501,7 @@ runtime::Module FinalizeModule(const Map& inputs_arg, const Ta } } } + runtime::Module complete_mod = codegen::Build(mhost_all, target_host); for (const auto& it : device_modules) { if (it.operator->()) { @@ -587,15 +606,15 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { mixed_pass_list.push_back(BindTarget(target)); + mixed_pass_list.push_back(tir::transform::VerifyMemory()); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + bool is_entry_func = false; if (mixed_mod->functions.size() == 1) { is_entry_func = pass_ctx->GetConfig("tir.is_entry_func", Bool(true)).value(); mixed_pass_list.push_back(AnnotateEntryFunc(is_entry_func)); } - mixed_pass_list.push_back(tir::transform::VerifyMemory()); - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { @@ -663,8 +682,8 @@ IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target) { device_pass_list.push_back(tir::transform::LowerWarpMemory()); device_pass_list.push_back(tir::transform::Simplify()); device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); - device_pass_list.push_back(tir::transform::LowerIntrin()); device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + device_pass_list.push_back(tir::transform::LowerIntrin()); auto device_opt_mod = transform::Sequential(device_pass_list); From b7f27d08b5b0a63c8516975b55f76a54d9c77c7d Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Thu, 7 Oct 2021 15:28:12 +0300 Subject: [PATCH 25/35] Fix format --- python/tvm/driver/build_module.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git 
a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index d7fb1c143cb5..adcc427ab5e3 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -127,8 +127,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -267,11 +266,9 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError( - "The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError( - "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -307,15 +304,13 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host From e1658b595d17ca0ade8165b6a0c764ab523fbf8d Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 8 Oct 2021 12:10:09 +0300 Subject: [PATCH 26/35] Refactor build --- include/tvm/driver/driver_api.h | 9 ++- python/tvm/driver/build_module.py | 44 --------------- src/driver/driver_api.cc | 92 +++++-------------------------- 3 files changed, 19 insertions(+), 126 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index bf136a273887..3849b6922ba5 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -43,15 +43,14 @@ #include namespace tvm { - /*! * \brief Returns the optimized IRModule for original fused module (pre split) that contains device * and host code. * \param mixed_mod The original mixed module. * \param target The device Target. * \return The result optimized mixed module. - */ -IRModule OptimizeMixedModule(IRModule mixed_mod, Target target); +// */ +IRModule MixedModulePassManager(IRModule mixed_mod, Target target); /*! * \brief Returns the optimized IRModule for the device Target after device/host from mixed module. @@ -59,7 +58,7 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target); * \param target The device Target. * \return The result optimized device module. */ -IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target); +IRModule DeviceModulePassManager(IRModule mixed_mod, Target target); /*! * \brief Returns the optimized IRModule for the host Target after device/host from mixed module. @@ -67,7 +66,7 @@ IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target); * \param target_host The host Target. * \return The result optimized host module. 
*/ -IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host); +IRModule HostModulePassManager(IRModule mixed_mod, Target target_host); /*! * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index adcc427ab5e3..c3b6aec63e2c 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -196,50 +196,6 @@ def build( See the note on :any:`tvm.target` on target string format. """ - # if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)): - # # TODO(@mikepapadim) replace with te_lower - # input_mod = lower(inputs, args, name=name, binds=binds) - # elif isinstance(inputs, (list, tuple, container.Array)): - # merged_mod = tvm.IRModule({}) - # for x in inputs: - # merged_mod.update(lower(x)) - # input_mod = merged_mod - # elif not isinstance(inputs, (dict, container.Map)): - # raise ValueError( - # f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, " - # f"but got {type(inputs)}." - # ) - - # # starts here - # if not isinstance(inputs, (dict, container.Map)): - # target = Target.current() if target is None else target - # target = target if target else "llvm" - # target_input_mod = {target: input_mod} - # else: - # target_input_mod = inputs - - # for tar, mod in target_input_mod.items(): - # if not isinstance(tar, (str, Target)): - # raise ValueError( - # "The key of inputs must be str or " "Target when inputs is dict.") - # if not isinstance(mod, tvm.IRModule): - # raise ValueError( - # "inputs must be Schedule, IRModule," "or dict of str to IRModule.") - - # target_input_mod, target_host = Target.check_and_update_host_consist( - # target_input_mod, target_host - # ) - - # if not target_host: - # for tar, mod in target_input_mod.items(): - # tar = Target(tar) - # device_type = ndarray.device(tar.kind.name, 0).device_type - # if device_type == ndarray.cpu(0).device_type: - # target_host = tar - # break - # if not target_host: - # target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - if isinstance(inputs, schedule.Schedule): if args is None: raise ValueError("args must be given for build from schedule") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 4889c19be679..2116a283377b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -48,6 +48,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using tvm::Array; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -393,19 +394,15 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& ICHECK(mod_mixed.defined()) << "This module must be defined"; - VLOG_CONTEXT << target->str(); - VLOG(0) << "Executing module pass with opt level: "; + mod_mixed = MixedModulePassManager(mod_mixed, target); - mod_mixed = OptimizeMixedModule(mod_mixed, target); + IRModule host_mod = HostModulePassManager(mod_mixed, target_host); - auto host_mod = OptimizeHostModule(mod_mixed, target_host); + IRModule device_mod = DeviceModulePassManager(mod_mixed, target); - auto device_mod = OptimizeDeviceModule(mod_mixed, target); - - // some final misc checks. 
auto keys = target->GetKeys(); - // CheckAndUpdateHostConsistency(&target, &target_host); + CheckAndUpdateHostConsistency(&target, &target_host); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && device_mod->functions.size() == 0) { @@ -416,47 +413,6 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } -std::pair TargetTypeMangling(const Map& inputs_arg, Target target, - Target target_host_arg) { - Target target_input_mod, target_host; - - target = !target.defined() ? target.Current() : target; - - std::vector device_modules; - Map inputs = inputs_arg; - target_host = target_host_arg; - - CheckAndUpdateHostConsistency(&inputs, &target_host); - - if (!target_host.defined()) { - for (const auto& it : inputs) { - if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { - target_host = it.first; - break; - } - } - } - - if (!target_host.defined()) { - target_host = DefaultTargetHost(target_host); - } - CheckAndUpdateHostConsistency(&inputs, &target_host); - - return {target_input_mod, target_host}; -} - -TVM_REGISTER_GLOBAL("driver.target_mangling") - .set_body_typed([](const Map& inputs_arg, Target target, - Target target_host_arg) { - return TargetTypeMangling(inputs_arg, target, target_host_arg).first; - }); - -TVM_REGISTER_GLOBAL("driver.host_target_mangling") - .set_body_typed([](const Map& inputs_arg, Target target, - Target target_host_arg) { - return TargetTypeMangling(inputs_arg, target, target_host_arg).second; - }); - runtime::Module FinalizeModule(const Map& inputs_arg, const Target& host_target) { std::vector device_modules; Map inputs = inputs_arg; @@ -599,10 +555,10 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return build(inputs, target_host); } -IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { +IRModule MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - Array mixed_pass_list; + Array mixed_pass_list; mixed_pass_list.push_back(BindTarget(target)); @@ -633,29 +589,18 @@ IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) { } mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - auto opt_mixed = transform::Sequential(mixed_pass_list); - return opt_mixed(std::move(mixed_mod)); + return LowerWithPassList(mixed_mod, mixed_pass_list); } -TVM_REGISTER_GLOBAL("driver.get_mod_mixed").set_body_typed([](IRModule mod, Target target) { - return OptimizeMixedModule(mod, target); -}); - -TVM_REGISTER_GLOBAL("driver.get_device_mod").set_body_typed([](IRModule mod, Target target) { - return OptimizeDeviceModule(mod, target); -}); - -TVM_REGISTER_GLOBAL("driver.get_host_mod").set_body_typed([](IRModule mod, Target target_host) { - return OptimizeHostModule(mod, target_host); -}); - -IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host) { - Array host_pass_list; +IRModule HostModulePassManager(IRModule mixed_mod, Target target_host) { + Array host_pass_list; host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; })); + ICHECK(mixed_mod.defined()) << "This module must be defined"; + host_pass_list.push_back(BindTarget(target_host)); host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); @@ -664,13 +609,10 @@ IRModule OptimizeHostModule(IRModule mixed_mod, Target target_host) { 
host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); host_pass_list.push_back(tir::transform::CombineContextCall()); - auto host_module = transform::Sequential(host_pass_list); - ICHECK(mixed_mod.defined()) << "This module must be defined"; - - return host_module(mixed_mod); + return LowerWithPassList(mixed_mod, host_pass_list); } -IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target) { +IRModule DeviceModulePassManager(IRModule mixed_mod, Target target) { Array device_pass_list; device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == @@ -685,11 +627,7 @@ IRModule OptimizeDeviceModule(IRModule mixed_mod, Target target) { device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); device_pass_list.push_back(tir::transform::LowerIntrin()); - auto device_opt_mod = transform::Sequential(device_pass_list); - - auto mdevice = device_opt_mod(mixed_mod); - - return device_opt_mod(mixed_mod); + return LowerWithPassList(mixed_mod, device_pass_list); } } // namespace tvm From a71a0afca25bedd6ffbeb84319e71154da32a923 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 8 Oct 2021 13:56:17 +0300 Subject: [PATCH 27/35] Expose splitted composite passes to python --- include/tvm/driver/driver_api.h | 26 ++++++++++------- src/driver/driver_api.cc | 49 +++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 3849b6922ba5..45a938247cc8 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -30,6 +30,7 @@ #define TVM_DRIVER_DRIVER_API_H_ #include +#include #include #include #include @@ -43,30 +44,34 @@ #include namespace tvm { +using tvm::transform::Pass; + /*! - * \brief Returns the optimized IRModule for original fused module (pre split) that contains device - * and host code. + * \brief Configures and returns the composite Pass for the fused module (pre split) that contains + * device and host code. * \param mixed_mod The original mixed module. * \param target The device Target. - * \return The result optimized mixed module. + * \return The composite Pass for the fused module. // */ -IRModule MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); /*! - * \brief Returns the optimized IRModule for the device Target after device/host from mixed module. + * \brief Configures and returns the composite Pass for the device Target after device/host from + * mixed module. * \param mixed_mod The optimized mixed module. * \param target The device Target. - * \return The result optimized device module. + * \return The composite Pass for the device module. */ -IRModule DeviceModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); /*! - * \brief Returns the optimized IRModule for the host Target after device/host from mixed module. + * \brief Configures and returns the composite Pass for the host Target after device/host from mixed + * module. * \param mixed_mod The optimized mixed module. * \param target_host The host Target. - * \return The result optimized host module. + * \return The composite Pass for the host module. 
*/ -IRModule HostModulePassManager(IRModule mixed_mod, Target target_host); +TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); /*! * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) @@ -160,6 +165,7 @@ TVM_DLL runtime::Module build(const Map& input, const Target& * \return The built module that contains code for different processors. */ TVM_DLL runtime::Module build(const Map& input, const Target& target_host); + } // namespace tvm #endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2116a283377b..3f270ec9c0ae 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -49,6 +49,7 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; using tvm::Array; +using tvm::transform::Pass; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -159,7 +160,7 @@ transform::Pass BindTarget(Target target) { static transform::Pass AnnotateEntryFunc(bool b) { auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(b)); + return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); }; return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); } @@ -275,6 +276,11 @@ IRModule LowerWithPassList(IRModule mod, Array pass_list) return mod; } +IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { + mod = seq(std::move(mod)); + return mod; +} + IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds) { // Convert te schedule to IRModule @@ -394,11 +400,11 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& ICHECK(mod_mixed.defined()) << "This module must be defined"; - mod_mixed = MixedModulePassManager(mod_mixed, target); + mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); - IRModule host_mod = HostModulePassManager(mod_mixed, target_host); + IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - IRModule device_mod = DeviceModulePassManager(mod_mixed, target); + IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); auto keys = target->GetKeys(); @@ -555,20 +561,18 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return build(inputs, target_host); } -IRModule MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - Array mixed_pass_list; + Array mixed_pass_list; mixed_pass_list.push_back(BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - bool is_entry_func = false; if (mixed_mod->functions.size() == 1) { - is_entry_func = pass_ctx->GetConfig("tir.is_entry_func", Bool(true)).value(); - mixed_pass_list.push_back(AnnotateEntryFunc(is_entry_func)); + mixed_pass_list.push_back(AnnotateEntryFunc(true)); } bool detect_global_barrier = @@ -589,10 +593,15 @@ IRModule MixedModulePassManager(IRModule mixed_mod, Target target) { } mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - return LowerWithPassList(mixed_mod, mixed_pass_list); + return transform::Sequential(mixed_pass_list); } -IRModule 
HostModulePassManager(IRModule mixed_mod, Target target_host) { +TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target) { + return MixedModulePassManager(mixed_mod, target); + }); + +transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { Array host_pass_list; host_pass_list.push_back(Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != @@ -609,11 +618,16 @@ IRModule HostModulePassManager(IRModule mixed_mod, Target target_host) { host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); host_pass_list.push_back(tir::transform::CombineContextCall()); - return LowerWithPassList(mixed_mod, host_pass_list); + return transform::Sequential(host_pass_list); } -IRModule DeviceModulePassManager(IRModule mixed_mod, Target target) { - Array device_pass_list; +TVM_REGISTER_GLOBAL("driver.host_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return HostModulePassManager(mixed_mod, target_host); + }); + +transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + Array device_pass_list; device_pass_list.push_back(Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; @@ -627,7 +641,12 @@ IRModule DeviceModulePassManager(IRModule mixed_mod, Target target) { device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); device_pass_list.push_back(tir::transform::LowerIntrin()); - return LowerWithPassList(mixed_mod, device_pass_list); + return transform::Sequential(device_pass_list); } +TVM_REGISTER_GLOBAL("driver.device_mod_passes") + .set_body_typed([](IRModule mixed_mod, Target target_host) { + return DeviceModulePassManager(mixed_mod, target_host); + }); + } // namespace tvm From 0ca60e9a664cc512a463fb21a8abeb8d1f3abb01 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Fri, 8 Oct 2021 14:26:16 +0300 Subject: [PATCH 28/35] Format files --- python/tvm/driver/build_module.py | 18 +++++++++++------- python/tvm/target/codegen.py | 5 +++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index c3b6aec63e2c..263417a9ae29 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -120,14 +120,14 @@ def lower( m : IRModule The result IRModule """ - # TODO(@mikepapadim) introduce ffi.relay.lower_te_pass() if isinstance(inp, IRModule): return ffi.lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -212,7 +212,7 @@ def build( f"Inputs must be Schedule, IRModule or dict of target to IRModule, " f"but got {type(inputs)}." 
) - # starts here + if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -222,9 +222,11 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError( + "The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError( + "inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -260,13 +262,15 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module( + [rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module( + [rt_mod_host], target_host) else: to_return = rt_mod_host diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 862366d0c082..3fa6ad08db6a 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -36,7 +36,7 @@ def build_module(mod, target): The corressponding module. """ target = Target(target) if isinstance(target, str) else target - return _ffi_api.Codegen(mod, target) + return _ffi_api.Build(mod, target) def llvm_lookup_intrinsic_id(name): @@ -73,4 +73,5 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError("LLVM version is not available, please check if you build with LLVM") + raise RuntimeError( + "LLVM version is not available, please check if you build with LLVM") From 49cbe5321ebaa005c02b44d75fbe67005124ac1d Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Tue, 12 Oct 2021 15:49:24 +0300 Subject: [PATCH 29/35] Test fix --- python/tvm/driver/build_module.py | 15 +++++---------- python/tvm/target/codegen.py | 3 +-- src/driver/driver_api.cc | 2 +- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 263417a9ae29..61cd0bec25fa 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -126,8 +126,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -222,11 +221,9 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError( - "The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError( - "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError("inputs must be Schedule, IRModule," "or dict of str to 
IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -262,15 +259,13 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 3fa6ad08db6a..0ab4cb005cb4 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -73,5 +73,4 @@ def llvm_version_major(allow_none=False): except AttributeError: if allow_none: return None - raise RuntimeError( - "LLVM version is not available, please check if you build with LLVM") + raise RuntimeError("LLVM version is not available, please check if you build with LLVM") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 3f270ec9c0ae..1f52e7c0d79d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -572,7 +572,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (mixed_mod->functions.size() == 1) { - mixed_pass_list.push_back(AnnotateEntryFunc(true)); + // mixed_pass_list.push_back(AnnotateEntryFunc(true)); } bool detect_global_barrier = From 6e1203e4633a22144192479dbaef626bb1a7c228 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Tue, 12 Oct 2021 18:44:47 +0300 Subject: [PATCH 30/35] Fix for annotating entry funcs on code targeting CPU --- src/driver/driver_api.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 79d82b15558c..6f6cec9755ca 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -584,8 +584,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if (mixed_mod->functions.size() == 1) { - // mixed_pass_list.push_back(AnnotateEntryFunc(true)); + if ((mixed_mod->functions.size() == 1) && (target->kind->device_type != kDLCPU)) { + mixed_pass_list.push_back(AnnotateEntryFunc(true)); } bool detect_global_barrier = From 2c0305b38b006785814f1822aab67eec65a6dcb3 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Tue, 12 Oct 2021 20:27:52 +0300 Subject: [PATCH 31/35] Prevent entry funcs to be annotated when compiling for CPU with C runtime enabled --- src/driver/driver_api.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 6f6cec9755ca..801c8c7ad76b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -56,6 +56,12 @@ bool LLVMEnabled() { return pf != nullptr; } +bool ShouldAnnoateEntryFunc(const Target target, const IRModule mod) { + const bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); + const bool single_entry_func = (mod->functions.size() == 1); + return single_entry_func && 
!target_c_runtime; +} + /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { if (target.defined() && target->kind->device_type == kDLCPU) { @@ -194,7 +200,7 @@ Array CreatePassList(bool disable_loop_partition) { Array user_lower_phase2 = Array(); Array user_lower_phase3 = Array(); - // phase pasees is of the form + // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { const IntImmNode* phase_num = phase_pass[0].as(); @@ -584,7 +590,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if ((mixed_mod->functions.size() == 1) && (target->kind->device_type != kDLCPU)) { + if (ShouldAnnoateEntryFunc(target, mixed_mod)) { mixed_pass_list.push_back(AnnotateEntryFunc(true)); } From b955799cc773a53e79b7a14cf7e432beda9e0db8 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Tue, 12 Oct 2021 21:37:51 +0300 Subject: [PATCH 32/35] Guard for aot executor entry --- src/driver/driver_api.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 801c8c7ad76b..e659421c23c4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -56,10 +56,10 @@ bool LLVMEnabled() { return pf != nullptr; } -bool ShouldAnnoateEntryFunc(const Target target, const IRModule mod) { - const bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); +bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) { + const bool aot_executor = (target->GetAttr("executor").value_or("") == "aot"); const bool single_entry_func = (mod->functions.size() == 1); - return single_entry_func && !target_c_runtime; + return single_entry_func && !aot_executor; } /*! \return The default host target for a given device target */ @@ -590,7 +590,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); - if (ShouldAnnoateEntryFunc(target, mixed_mod)) { + if (ShouldAnnotateEntryFunc(target, mixed_mod)) { mixed_pass_list.push_back(AnnotateEntryFunc(true)); } From aed07650681cd843d77b14812b9ab272fc24bc0c Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 13 Oct 2021 01:42:14 +0300 Subject: [PATCH 33/35] Sphix format --- python/tvm/driver/build_module.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 61cd0bec25fa..ae598ce5c44d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -16,8 +16,7 @@ # under the License. # pylint: disable=invalid-name -"""The build utils in python. -""" +"""The build utils in python.""" from typing import Union, Optional, List, Mapping @@ -93,28 +92,22 @@ def lower( simple_mode: bool = False, ) -> IRModule: """Lowering step before build into target. - Parameters ---------- inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built - args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function for TE schedule. It should be None if we want to lower TensorIR. 
- name : str The name of the result function. - binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. - simple_mode : bool Whether only output simple and compact statement, this will skip LoopPartition, api wrapper generation and Unrolling. - Returns ------- m : IRModule @@ -126,7 +119,8 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -221,9 +215,11 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError( + "The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError( + "inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ -259,13 +255,15 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module( + [rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module( + [rt_mod_host], target_host) else: to_return = rt_mod_host From c3505fa8e410ee34707d34fd88c47876baa7ecc2 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 13 Oct 2021 10:39:09 +0300 Subject: [PATCH 34/35] Sanity fix --- python/tvm/driver/build_module.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index ae598ce5c44d..67a33397cb92 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -119,8 +119,7 @@ def lower( return ffi.lower_primfunc(inp, name, simple_mode) if isinstance(inp, schedule.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError( - "Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) def build( @@ -215,11 +214,9 @@ def build( for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): - raise ValueError( - "The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError( - "inputs must be Schedule, IRModule," "or dict of str to IRModule.") + raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") target_input_mod, target_host = Target.check_and_update_host_consist( target_input_mod, target_host @@ 
-255,15 +252,13 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module( - [rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) else: to_return = rt_mod_host From 2162c858dea926f12342337735437065245428af Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Wed, 13 Oct 2021 13:03:38 +0300 Subject: [PATCH 35/35] Sphinx fix --- python/tvm/driver/build_module.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 67a33397cb92..429b3e1727cc 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -92,22 +92,28 @@ def lower( simple_mode: bool = False, ) -> IRModule: """Lowering step before build into target. + Parameters ---------- inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] The TE schedule or TensorIR PrimFunc/IRModule to be built + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function for TE schedule. + It should be None if we want to lower TensorIR. name : str The name of the result function. + binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the Tensor to Buffer which specified the data layout requirement of the function. By default, a new compact buffer is created for each tensor in the argument. + simple_mode : bool Whether only output simple and compact statement, this will skip LoopPartition, api wrapper generation and Unrolling. + Returns ------- m : IRModule @@ -132,15 +138,19 @@ def build( ): """Build a function with arguments as signature. Code will be generated for devices coupled with target information. + Parameters ---------- inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]] The input to be built + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] The argument lists to the function. + target : Optional[Union[str, Target]] The target and option of the compilation. + target_host : Optional[Union[str, Target]] Host compilation target, if target is device. When TVM compiles device specific program such as CUDA, @@ -149,21 +159,27 @@ def build( target_host is used to specify the host side codegen target. By default, llvm is used if it is enabled, otherwise a stackvm interpreter is used. + name : Optional[str] The name of result function. + binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the binding of symbolic buffer to Tensor. By default, a new buffer is created for each tensor in the argument. + Returns ------- ret : tvm.module A module that combines both host and device code. + Examples ________ There are two typical example uses of this function depending on the type of the argument `inputs`: 1. it is an IRModule. + .. code-block:: python + n = 2 A = te.placeholder((n,), name='A') B = te.placeholder((n,), name='B') @@ -171,8 +187,11 @@ def build( s = tvm.te.create_schedule(C.op) m = tvm.lower(s, [A, B, C], name="test_add") rt_mod = tvm.build(m, target="llvm") + 2. 
it is a dict of compilation target to IRModule. + .. code-block:: python + n = 2 A = te.placeholder((n,), name='A') B = te.placeholder((n,), name='B') @@ -183,11 +202,11 @@ def build( m1 = tvm.lower(s1, [A, B, C], name="test_add1") m2 = tvm.lower(s2, [A, B, C], name="test_add2") rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm") + Note ---- See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, schedule.Schedule): if args is None: raise ValueError("args must be given for build from schedule")
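The patches above register the three composite pipelines ("driver.mixed_mod_passes", "driver.host_mod_passes", "driver.device_mod_passes") so that they can be reached from Python, but the series never shows them being driven end to end. Below is a minimal sketch of one way they could be exercised. The retrieval through tvm._ffi.get_global_func mirrors the pattern already used in build_module.py for the CRT metadata modules; the toy TE workload, the chosen targets, and the assumption that each global returns a transform.Sequential that can be invoked on an IRModule like any other pass are illustrative assumptions, not part of the patches.

import tvm
from tvm import te

# Build a trivial TE workload and lower it to a mixed IRModule.
n = 16
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
mixed_mod = tvm.lower(s, [A, B], name="add_one")

target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")

# Assumed FFI entry points; the names come from the TVM_REGISTER_GLOBAL calls
# added in the driver_api.cc changes above.
mixed_mod_passes = tvm._ffi.get_global_func("driver.mixed_mod_passes")
host_mod_passes = tvm._ffi.get_global_func("driver.host_mod_passes")
device_mod_passes = tvm._ffi.get_global_func("driver.device_mod_passes")

# Mirror the ordering in SplitMixedModule: run the mixed pipeline first, then
# derive the host and device modules from the optimized mixed module.
mixed_mod = mixed_mod_passes(mixed_mod, target)(mixed_mod)
host_mod = host_mod_passes(mixed_mod, target_host)(mixed_mod)
device_mod = device_mod_passes(mixed_mod, target)(mixed_mod)

print(host_mod)

If this sketch holds, the split lets each stage of the lowering pipeline be inspected or unit-tested from Python without going through tvm.build, which is presumably why the pass managers are exposed individually rather than only through the monolithic build path.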