From cd1318b458420c6ff396e973c35c134f8d35ec1f Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 27 Jun 2022 11:02:07 -0700 Subject: [PATCH] - Masa & Andrew's comments --- include/tvm/ir/module.h | 6 ++-- python/tvm/relay/backend/interpreter.py | 2 +- python/tvm/relay/backend/vm.py | 5 ++-- python/tvm/relay/build_module.py | 28 +++++++++++-------- src/relay/backend/aot_executor_codegen.cc | 16 ++++------- .../contrib/arm_compute_lib/codegen.cc | 2 +- src/relay/backend/contrib/bnns/codegen.cc | 2 +- .../contrib/codegen_json/codegen_json.h | 13 +++++---- src/relay/backend/contrib/cutlass/codegen.cc | 7 +++-- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/relay/backend/contrib/tensorrt/codegen.cc | 2 +- .../backend/contrib/verilator/codegen.cc | 2 +- src/relay/backend/graph_executor_codegen.cc | 14 ++++------ src/relay/backend/te_compiler.cc | 3 +- src/relay/backend/vm/compiler.cc | 16 ++++------- src/tir/transforms/extract_constants.cc | 6 ++-- 16 files changed, 61 insertions(+), 65 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 9f2955c098229..f73f2230df4d7 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -519,12 +519,12 @@ constexpr const char* kConstantMemoryPools = "constant_memory_pools"; /* * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The - * node will record the index into this array. See also kConstNameToNDArray below, which is + * node will record the index into this array. See also kConstNameToConstant below, which is * the analog for Realy Functions. * * Type: Array */ -constexpr const char* kConstantsArray = "Constants"; +constexpr const char* kConstants = "constants"; /*! * \brief All the runtime::Modules accumulated during compilation by external codegen. 
These @@ -542,7 +542,7 @@ constexpr const char* kExternalMods = "external_mods"; * * Type: Map */ -constexpr const char* kConstNameToNDArray = "const_name_to_ndarray"; +constexpr const char* kConstNameToConstant = "const_name_to_constant"; } // namespace attr } // namespace tvm diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 819e5eda41f58..020736beb5c43 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -195,7 +195,7 @@ class Interpreter(Executor): The runtime device to run the code on. target : tvm.Target - The target option to build the function. + The target option to build the function. Only homogeneous execution is supported. CAUTION: Despite the API the module is prepared upon each call to evaluate rather than once in create_executor. diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index d4a82cd8d4279..bc11d43cb0ca5 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -198,8 +198,9 @@ class VMExecutor(Executor): device : :py:class:`~tvm.runtime.Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ def __init__(self, mod, device, target): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 27eccb2abacc5..8fe8a86b548b8 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -570,15 +570,16 @@ class GraphExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - raw_targets : Array[tvm.target.Target] - The available targets. 
+ target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ - def __init__(self, mod, device, raw_targets): + def __init__(self, mod, device, target): assert mod is not None self.mod = mod self.device = device - self.raw_targets = raw_targets + self.target = target def _make_executor(self, expr=None): if expr: @@ -589,7 +590,7 @@ def _make_executor(self, expr=None): raise ValueError( "Graph Executor only supports static graphs, got output type", ret_type ) - mod = build(self.mod, target=self.raw_targets) + mod = build(self.mod, target=self.target) gmodule = _graph_executor.GraphModule(mod["default"](self.device)) def _unflatten(flat_iter, cur_type): @@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - raw_targets : Array[tvm.target.Target] - The available targets. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ - def __init__(self, mod, device, raw_targets): + def __init__(self, mod, device, target): assert mod is not None self.mod = mod self.device = device - self.raw_targets = raw_targets - assert raw_targets[0].attrs.get("executor", "graph") == "aot" + self.target = target def _make_executor(self, expr=None): if expr: @@ -648,7 +649,7 @@ def _make_executor(self, expr=None): ret_type = self.mod["main"].checked_type.ret_type if _ty.is_dynamic(ret_type): raise ValueError("AOT Executor only supports static graphs, got output type", ret_type) - mod = build(self.mod, target=self.raw_targets) + mod = build(self.mod, target=self.target) # NOTE: Given AOT requires use of the "c" backend, must export/import to compile the # generated code. 
@@ -722,6 +723,8 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N target : any multi-target like object, see Target.canon_multi_target For homogeneous compilation, the unique build target. For heterogeneous compilation, a dictionary or list of possible build targets. + CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so + heterogeneous compilation is not yet supported. params : dict of str to NDArray Input parameters to the graph that do not change @@ -737,11 +740,14 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N if device is not None: assert device.device_type == raw_targets[0].kind.device_type else: + # Use the first target as the device. device = _nd.device(raw_targets[0].kind.device_type, 0) if params is not None: mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) + assert raw_targets[0].attrs.get("executor") == kind + if kind == "debug": assert len(raw_targets) == 1, "The interpreter currently only supports a single target" return _interpreter.Interpreter(mod, device, raw_targets[0]) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 518a5fcb7d203..ae60970b78af3 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1170,12 +1170,10 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Collect any constants extracted by external codegen.
ret.params = std::unordered_map(); - Map const_name_to_ndarray = - lowered_mod - ->GetAttr>(tvm::attr::kConstNameToNDArray, - Map()) - .value(); - for (const auto& kv : const_name_to_ndarray) { + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + for (const auto& kv : const_name_to_constant) { ICHECK(ret.params.emplace(kv.first, kv.second).second); } @@ -1223,10 +1221,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { } // Collect any runtime modules generated by external codegen. - ret.external_mods = lowered_mod - ->GetAttr>(tvm::attr::kExternalMods, - Array()) - .value(); + ret.external_mods = + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); // This is the point where we separate the functions in the module by target VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 60d37aa7509da..81a5b5bbd9d8c 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -393,7 +393,7 @@ runtime::Module ACLCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); - // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes // a callback which calls backend::UpdateConstants to capture the map before the function // 'disappears' into lowered form, on the assumption the visit order and thus constant // names match those generated by the JSONSerializer. 
diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 743525a3c85d1..3791773ad67d6 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -137,7 +137,7 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); - // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes // a callback which calls backend::UpdateConstants to capture the map before the function // 'disappears' into lowered form, on the assumption the visit order and thus constant // names match those generated by the JSONSerializer. diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index e6010bcac3b87..de6d0f74061b8 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -169,8 +169,8 @@ class JSONSerializer : public MemoizedExprTranslator& const_name_to_ndarray() const { - return const_name_to_ndarray_; + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; } /*! 
@@ -260,9 +260,10 @@ class JSONSerializer : public MemoizedExprTranslator VisitExpr_(const ConstantNode* constant_node) { std::string name = symbol_ + "_const_" + std::to_string(const_names_.size()); - VLOG(1) << "Will require parameter '" << name << "' to be available at runtime"; - ICHECK_EQ(const_name_to_ndarray_.count(name), 0); - const_name_to_ndarray_.emplace(name, constant_node->data); + VLOG(1) << "Will require parameter '" << name + << "' to be supplied by the ConstLoaderModule at runtime"; + ICHECK_EQ(const_name_to_constant_.count(name), 0); + const_name_to_constant_.emplace(name, constant_node->data); const_names_.push_back(name); auto node = std::make_shared(name, /*op_type=*/"const"); return AddNode(node, GetRef(constant_node)); @@ -361,7 +362,7 @@ class JSONSerializer : public MemoizedExprTranslator const_name_to_ndarray_; + std::unordered_map const_name_to_constant_; /*! * \brief The domain of the above map, but in order the constants were encountered during * translation. diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index d5e9ab6a7bd7f..bbc6a97a78eeb 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -47,7 +47,9 @@ namespace { Target GetCutlassTarget() { Target target = Target::Current(/*allow_not_defined=*/true); if (!target.defined() || target->kind->name != "cutlass") { - // Use the default CUTLASS compilation options. + // Use the default CUTLASS compilation options if no specific "cutlass" target was given + // in the overall targets list. In that case target_hooks.cc will invoke the custom pass + // without pushing any target instance onto the implicit target stack. 
target = Target("cutlass"); } return target; @@ -912,8 +914,7 @@ transform::Pass CompileForCutlassImpl() { Target target = GetCutlassTarget(); runtime::Module runtime_mod = (*pf)(mod, target); Array external_mods = - mod->GetAttr>(tvm::attr::kExternalMods, Array()) - .value(); + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); external_mods.push_back(runtime_mod); return WithAttr(mod, tvm::attr::kExternalMods, external_mods); }; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index dfe37a3654357..2f47c23a7cf9b 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -586,7 +586,7 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); - // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes // a callback which calls backend::UpdateConstants to capture the map before the function // 'disappears' into lowered form, on the assumption the visit order and thus constant // names match those generated by the JSONSerializer. diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 20a19a36382f1..e08cd240d4d1e 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -319,7 +319,7 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) { std::string graph_json = serializer.GetJSON(); VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; - // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // Note that serializer.const_name_to_constant() is ignored. 
Instead the TECompiler invokes // a callback which calls backend::UpdateConstants to capture the map before the function // 'disappears' into lowered form, on the assumption the visit order and thus constant // names match those generated by the JSONSerializer. diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc index ef88266c607f8..2e6fb13263144 100644 --- a/src/relay/backend/contrib/verilator/codegen.cc +++ b/src/relay/backend/contrib/verilator/codegen.cc @@ -112,7 +112,7 @@ runtime::Module VerilatorBackend(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); - // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes // a callback which calls backend::UpdateConstants to capture the map before the function // 'disappears' into lowered form, on the assumption the visit order and thus constant // names match those generated by the JSONSerializer. diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 444d56b125727..25238c5e573f0 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -262,18 +262,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr>(tvm::attr::kExternalMods, Array()) - .value(); + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); // Collect any constants extracted by external codegen. 
ret.params = std::unordered_map(); - Map const_name_to_ndarray = - lowered_mod - ->GetAttr>(tvm::attr::kConstNameToNDArray, - Map()) - .value(); - for (const auto& kv : const_name_to_ndarray) { + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or(Map()); + for (const auto& kv : const_name_to_constant) { ICHECK(ret.params.emplace(kv.first, kv.second).second); } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index a81d63296c8a0..4390e90b2cf35 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -1196,8 +1196,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // annotate the module with the resulting runtime modules. // TODO(mbs): runtime modules should be first class rather than attributes. Array external_mods = - module->GetAttr>(tvm::attr::kExternalMods, Array()) - .value(); + module->GetAttr>(tvm::attr::kExternalMods).value_or({}); Array new_external_mods = compiler->LowerExternalFunctions(); VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size() << " new external modules"; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 37010d5b4b898..a8bd3df32a90f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1170,24 +1170,20 @@ void VMCompiler::Codegen() { // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time // and IRModule-at-a-time). Array external_mods = - context_.module - ->GetAttr>(tvm::attr::kExternalMods, Array()) - .value(); + context_.module->GetAttr>(tvm::attr::kExternalMods).value_or({}); // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes). 
- Map const_name_to_ndarray = - context_.module - ->GetAttr>(tvm::attr::kConstNameToNDArray, - Map()) - .value(); + Map const_name_to_constant = + context_.module->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, " - << external_mods.size() << " external runtime modules, " << const_name_to_ndarray.size() + << external_mods.size() << " external runtime modules, " << const_name_to_constant.size() << " external constants, and " << params_.size() << " local constants"; // Any constant bindings must be merged into the overall 'params' map we've directly accumulated // via the TECompiler callback. - for (const auto& kv : const_name_to_ndarray) { + for (const auto& kv : const_name_to_constant) { ICHECK_EQ(params_.count(kv.first), 0); params_.emplace(kv.first, kv.second); } diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 237f923516dab..f9e620ba3322b 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -80,14 +80,14 @@ tvm::transform::Pass ExtractPrimFuncConstants() { } auto* attrs = m->attrs.CopyOnWrite(); ConstArrayType constant_array_ = - (attrs->dict.count(tvm::attr::kConstantsArray)) - ? Downcast(attrs->dict[tvm::attr::kConstantsArray]) + (attrs->dict.count(tvm::attr::kConstants)) + ? Downcast(attrs->dict[tvm::attr::kConstants]) : ConstArrayType(); Applicator a = Applicator(); func->body = a.Apply(func->body, constant_array_); const ConstArrayType constant_list = a.constant_array_; if (constant_list.size()) { - attrs->dict.Set(tvm::attr::kConstantsArray, constant_list); + attrs->dict.Set(tvm::attr::kConstants, constant_list); } return GetRef(func); };