diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index e32ddb716bd55..715a805187538 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -479,8 +479,10 @@ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, namespace attr { +// Following are attributes for IRModule only. + /*! - * \brief Executor targetted by the module + * \brief Executor targeted by the module * * Type: Executor * @@ -507,10 +509,32 @@ constexpr const char* kRuntime = "runtime"; constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools"; /* - * \brief Module attribute for tir constants + * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The + * node will record the index into this array. See also kConstNameToNDArray below, which is + * the analog for Realy Functions. + * + * Type: Array */ constexpr const char* kConstantsArray = "Constants"; +/*! + * \brief All the runtime::Modules accumulated during compilation by external codegen. These + * modules must be either directly linked or captured in the final compilation artifact. + * + * Type: Array + */ +constexpr const char* kExternalMods = "external_mods"; + +/*! + * \brief All the named runtime::NDArrays accumulated during compilation by external codegen. + * Generally the associated runtime::Module will indicate it requires bindings for these names, + * and during module initialization these bindings will be recovered from a ConstLoaderModule. + * See also kConstantsArray above, which is the analog for PrimFuncs. + * + * Type: Map + */ +constexpr const char* kConstNameToNDArray = "const_name_to_ndarray"; + } // namespace attr } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 9eeb20f5f1ce9..33b36ff8af544 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -707,24 +707,24 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` """ + raw_targets = Target.canon_multi_target(target) if mod is None: mod = IRModule() if device is not None: - assert device.device_type == _nd.device(str(target), 0).device_type + assert device.device_type == raw_targets[0].kind.device_type else: - device = _nd.device(str(target), 0) + device = _nd.device(raw_targets[0].kind.device_type, 0) if params is not None: mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) - if isinstance(target, str): - target = Target(target) if kind == "debug": - return _interpreter.Interpreter(mod, device, target) + assert len(raw_targets) == 1, "The interpreter currently only supports a single target" + return _interpreter.Interpreter(mod, device, raw_targets[0]) if kind == "graph": - return GraphExecutor(mod, device, target) + return GraphExecutor(mod, device, raw_targets) if kind == "vm": - return VMExecutor(mod, device, target) + return VMExecutor(mod, device, raw_targets) if kind == "aot": - return AotExecutor(mod, device, target) + return AotExecutor(mod, device, raw_targets) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 694dbb45218ca..6229c793b965a 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. Returns ------- @@ -1412,7 +1412,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. Returns ------- @@ -1420,3 +1420,25 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): The pass. """ return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter) + + +def InlineCompilerFunctions(global_vars): + """Inlines all global functions bound to a global var in global_vars. + + Both the global "Compiler" attributed function, and any "Composite" functions it its body are + inlined. + + This pass may be useful for external codegen which needs to undo partitioning based on + properties of the entire partition. + + Parameters + ---------- + global_vars : Array[tvm.relay.GlobalVar] + The global vars of all 'Compiler' functions to inline. + + Returns + ------- + ret : tvm.transform.Pass + The pass. + """ + return _ffi_api.InlineCompilerFunctions(global_vars) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 381cfa0c9d1c8..15013b868f8d0 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1149,11 +1149,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. LoweredOutput ret; - ret.params = std::unordered_map>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_ndarray = + lowered_mod + ->GetAttr>(tvm::attr::kConstNameToNDArray, + Map()) + .value(); + for (const auto& kv : const_name_to_ndarray) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + // Collect any constants extracted during lowering. + for (const auto& kv : params_) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); } // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main @@ -1199,9 +1209,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { lowered_mod = pack_calls(lowered_mod); } - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any runtime modules generated by external codegen. + ret.external_mods = lowered_mod + ->GetAttr>(tvm::attr::kExternalMods, + Array()) + .value(); // This is the point where we separate the functions in the module by target VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); @@ -1214,8 +1226,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { << PrettyPrint(kv.second); } - ret.external_mods = external_modules.value(); - // Extract USMP metadata to pass onto metadata sources Map pool_var_info; std::vector pool_vars; @@ -1282,11 +1292,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; *rv = get_param_by_name(key); }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - *rv = get_param_id(key); - }); } else if (name == "get_irmodule") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); @@ -1328,17 +1333,11 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { runtime::NDArray get_param_by_name(String key) { auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second.second; + return (*it).second; } Array get_external_modules() { return output_.external_mods; } - int get_param_id(String key) { - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second.first; - } - Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8c1d83d39b09f..0de90dc977eac 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -86,17 +86,6 @@ struct ExecutorCodegen { return ret; } - std::unordered_map GetParamIds() { - std::unordered_map ret; - auto names = CallFunc>("list_params_name", nullptr); - for (const auto& expr : names) { - // Implicit cast from runtime::String to std::string - std::string key = expr; - ret[key] = CallFunc("get_param_id", key); - } - return ret; - } - Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); } diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 842ede3bf20b8..60d37aa7509da 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) { ACLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. + const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 72c32fb5b19ee..743525a3c85d1 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) { BNNSJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; } diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 4966f3f01c7d2..e6010bcac3b87 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include #include "../../../../runtime/contrib/json/json_node.h" @@ -150,7 +152,8 @@ class JSONSerializer : public MemoizedExprTranslator(func_); @@ -162,8 +165,18 @@ class JSONSerializer : public MemoizedExprTranslatorbody); } - /*!\brief Return the required params. */ - Array GetParams() const { return params_; } + /*! + * \brief Returns the accumulated map from constant names to the NDArray they must be bound to + * at runtime. Also referred to a 'params' elsewhere in the code. + */ + const std::unordered_map& const_name_to_ndarray() const { + return const_name_to_ndarray_; + } + + /*! + * \brief Return the constant names in order they were encountered during translation. + */ + const Array& const_names() const { return const_names_; } /*!\brief Return the generated json. */ std::string GetJSON() { @@ -245,11 +258,14 @@ class JSONSerializer : public MemoizedExprTranslator(vn)]; } - std::vector VisitExpr_(const ConstantNode* cn) { - std::string name = symbol_ + "_const_" + std::to_string(params_.size()); - params_.push_back(name); - auto node = std::make_shared(name, "const" /* op_type_ */); - return AddNode(node, GetRef(cn)); + std::vector VisitExpr_(const ConstantNode* constant_node) { + std::string name = symbol_ + "_const_" + std::to_string(const_names_.size()); + VLOG(1) << "Will require parameter '" << name << "' to be available at runtime"; + ICHECK_EQ(const_name_to_ndarray_.count(name), 0); + const_name_to_ndarray_.emplace(name, constant_node->data); + const_names_.push_back(name); + auto node = std::make_shared(name, /*op_type=*/"const"); + return AddNode(node, GetRef(constant_node)); } std::vector VisitExpr_(const TupleNode* tn) { @@ -340,8 +356,17 @@ class JSONSerializer : public MemoizedExprTranslator nodes_; /*! \brief Output of the JSON graph. */ std::vector heads_; - /*! \brief The list of required constants. */ - Array params_; + /*! + * \brief A map from constant names to NDArrays for each Constant encountered during + * translation to JSON. The JSON will record only the constant name. The actual NDArray must + * be made available at runtime from a ConstLoaderModule. + */ + std::unordered_map const_name_to_ndarray_; + /*! + * \brief The domain of the above map, but in order the constants were encountered during + * translation. + */ + Array const_names_; }; } // namespace contrib diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 772007792ae62..d5e9ab6a7bd7f 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -43,6 +43,16 @@ namespace cutlass { namespace { +/*! \brief Return the "cutlass" Target instance to use to guide compilation. */ +Target GetCutlassTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "cutlass") { + // Use the default CUTLASS compilation options. + target = Target("cutlass"); + } + return target; +} + using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, @@ -899,14 +909,13 @@ transform::Pass CompileForCutlassImpl() { VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; - Optional opt_cutlass_target = Target::Current(); - ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available"; - ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass"); - runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value()); + Target target = GetCutlassTarget(); + runtime::Module runtime_mod = (*pf)(mod, target); Array external_mods = - mod->GetAttr>("external_mods", Array()).value(); + mod->GetAttr>(tvm::attr::kExternalMods, Array()) + .value(); external_mods.push_back(runtime_mod); - return WithAttr(mod, "external_mods", external_mods); + return WithAttr(mod, tvm::attr::kExternalMods, external_mods); }; return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {}); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f17cdafa76a5f..dfe37a3654357 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -585,11 +585,15 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { DNNLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; #else DNNLModuleCodegen dnnl; diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 149cc485c7528..20a19a36382f1 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -318,11 +318,16 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc index 2c29896d1b0e7..ef88266c607f8 100644 --- a/src/relay/backend/contrib/verilator/codegen.cc +++ b/src/relay/backend/contrib/verilator/codegen.cc @@ -111,10 +111,15 @@ runtime::Module VerilatorBackend(const ObjectRef& ref) { VerilatorJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. // Create runtime object - auto n = make_object(func_name, graph_json, params); + auto n = make_object(func_name, graph_json, + serializer.const_names()); // Get Verilator compiler options auto ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index af426e5c71cbf..444d56b125727 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -259,21 +259,33 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any runtime modules generated by external codegen. + ret.external_mods = + lowered_mod + ->GetAttr>(tvm::attr::kExternalMods, Array()) + .value(); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_ndarray = + lowered_mod + ->GetAttr>(tvm::attr::kConstNameToNDArray, + Map()) + .value(); + for (const auto& kv : const_name_to_ndarray) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); } - ret.function_metadata = std::move(function_metadata_); - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any constants extracted during lowering. + for (const auto& kv : params_) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + ret.function_metadata = std::move(function_metadata_); // This is the point where we separate the functions in the module by target ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); - ret.external_mods = external_modules.value(); ret.metadata = ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */, {} /* output_tensor_types */, {} /* pools */, {} /* devices */, @@ -650,14 +662,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.second; - }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.first; + *rv = (*it).second; }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e9491b0a89010..a81d63296c8a0 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -1196,7 +1196,8 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // annotate the module with the resulting runtime modules. // TODO(mbs): runtime modules should be first class rather than attributes. Array external_mods = - module->GetAttr>("external_mods", Array()).value(); + module->GetAttr>(tvm::attr::kExternalMods, Array()) + .value(); Array new_external_mods = compiler->LowerExternalFunctions(); VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size() << " new external modules"; @@ -1218,7 +1219,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr device_contexts.Set(kv.first, kv.second); // copy-on-write. } - updated_module = WithAttrs(updated_module, {{"external_mods", std::move(external_mods)}, + updated_module = WithAttrs(updated_module, {{tvm::attr::kExternalMods, std::move(external_mods)}, {"device_contexts", std::move(device_contexts)}}); if (backend::IsAutoSchedulerEnabled()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 70080254c414d..847ea0b3bb22b 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -215,7 +215,11 @@ struct LoweredOutput { Map lowered_funcs; Array external_mods; Map function_metadata; - std::unordered_map> params; + /*! + * \brief Map from constant names (allocated by the codegen as constants are encountered) + * to the constant's value. + */ + std::unordered_map params; ExecutorCodegenMetadata metadata; }; @@ -241,7 +245,7 @@ struct ConstantUpdater : public ExprVisitor { void VisitExpr_(const ConstantNode* cn) final { std::string name = symbol_ + "_const_" + std::to_string(const_idx_++); - VLOG(1) << "Binding " << name << " to constant of type " << PrettyPrint(cn->checked_type()); + VLOG(1) << "binding '" << name << "' to constant of type " << PrettyPrint(cn->checked_type()); (*params_)[name] = cn->data; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8820a403bf709..23338ea851af1 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1152,11 +1152,31 @@ void VMCompiler::Codegen() { for (const auto& kv : per_tvm_target_modules) { ICHECK(kv.first->kind->device_type != kDLExtDev); } - Array ext_mods = - context_.module->GetAttr>("external_mods", Array()) + + // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time + // and IRModule-at-a-time). + Array external_mods = + context_.module + ->GetAttr>(tvm::attr::kExternalMods, Array()) + .value(); + + // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes). + Map const_name_to_ndarray = + context_.module + ->GetAttr>(tvm::attr::kConstNameToNDArray, + Map()) .value(); - VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build and " << ext_mods.size() - << " external runtime modules"; + + VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, " + << external_mods.size() << " external runtime modules, " << const_name_to_ndarray.size() + << " external constants, and " << params_.size() << " local constants"; + + // Any constant bindings must be merged into the overall 'params' map we've directly accumulated + // via the TECompiler callback. + for (const auto& kv : const_name_to_ndarray) { + ICHECK_EQ(params_.count(kv.first), 0); + params_.emplace(kv.first, kv.second); + } runtime::Module lib; if (per_tvm_target_modules.empty()) { @@ -1169,7 +1189,7 @@ void VMCompiler::Codegen() { } lib = - codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, + codegen::CreateMetadataModule(params_, lib, external_mods, config_->host_target, Runtime::Create("cpp"), Executor::Create("graph"), // DNS HACK relay::backend::ExecutorCodegenMetadata()); exec_->SetLib(lib); diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 3df07e4c57f5b..4fc8a01ed75a1 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -27,12 +27,28 @@ #include "../op/call/call.h" #include "tvm/relay/analysis.h" #include "tvm/relay/expr_functor.h" +#include "tvm/relay/transform.h" namespace tvm { namespace relay { namespace transforms { namespace { +/*! + * \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should + * be processed by a pass using \p compiler_filter. Otherwise returns null. + */ +const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && + (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { + return function_node; + } + } + return nullptr; +} + /*! * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given * module will be extended with the newly outlined functions. @@ -44,26 +60,22 @@ class Outliner : public MixedModeMutator { Expr Rewrite_(const CallNode* pre, const Expr& post) final { Call new_call = Downcast(post); - if (const auto* function_node = new_call->op.as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) { - auto function = GetRef(function_node); - DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler - << "' attribute should not have free variables"; - // Ask the cache to supply a unique global var for this function. - GlobalVar global_symbol = cache_->GetGlobalSymbol(function); - // Depending on the cache's implementation, two structurally equal (but not object equal) - // functions may be assigned the same global symbol. If so we'll lift it just once, but - // rewrite all the calls. - if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { - function = - WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); - mod_->Add(global_symbol, function); - } - // Update the call. - return WithFields(new_call, global_symbol); + if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) { + auto function = GetRef(function_node); + DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler + << "' attribute should not have free variables"; + // Ask the cache to supply a unique global var for this function. + GlobalVar global_symbol = cache_->GetGlobalSymbol(function); + // Depending on the cache's implementation, two structurally equal (but not object + // equal) functions may be assigned the same global symbol. If so we'll lift it just + // once, but rewrite all the calls. + if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { + function = + WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); + mod_->Add(global_symbol, function); } + // Update the call. + return WithFields(new_call, global_symbol); } return post; } @@ -71,8 +83,8 @@ class Outliner : public MixedModeMutator { private: /*! * \brief A cached mapping from functions to global variables. Depending on the implementation - * the cache may generate fresh symbols or require the function to already have a "global_symbol" - * attribute, and may share symbols between structurally equal functions. + * the cache may generate fresh symbols or require the function to already have a + * "global_symbol" attribute, and may share symbols between structurally equal functions. */ GlobalSymbolCache* cache_; /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ @@ -81,6 +93,72 @@ class Outliner : public MixedModeMutator { IRModule mod_; }; +/*! + * \brief Inline immediate calls to "Composite" functions. + */ +class InnerInliner : public MixedModeMutator { + public: + InnerInliner() = default; + + private: + using MixedModeMutator::Rewrite_; + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* function_node = new_call->op.as()) { + ICHECK(function_node->GetAttr(attr::kComposite).defined()); + ICHECK_EQ(function_node->params.size(), new_call->args.size()); + Map subst; + for (size_t i = 0; i < new_call->args.size(); ++i) { + subst.Set(function_node->params[i], new_call->args[i]); + } + return Bind(function_node->body, subst); + } + return post; + } +}; + +/*! + * \brief Inline calls to global "Compiler" functions with global var in \p global_vars. + * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body + * are inlined. + */ +class OuterInliner : public MixedModeMutator { + public: + OuterInliner(IRModule mod, Array global_vars_) + : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {} + + private: + using MixedModeMutator::Rewrite_; + + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + Call new_call = Downcast(post); + if (const auto* global_var_node = new_call->op.as()) { + auto global_var = GetRef(global_var_node); + if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) { + BaseFunc base_func = mod_->Lookup(global_var); + const auto* function_node = base_func.as(); + ICHECK(function_node); + ICHECK(function_node->GetAttr(attr::kCompiler).defined()); + ICHECK_EQ(function_node->params.size(), new_call->args.size()); + Map subst; + for (size_t i = 0; i < new_call->args.size(); ++i) { + subst.Set(function_node->params[i], new_call->args[i]); + } + Expr new_body = InnerInliner().VisitExpr(function_node->body); + return Bind(new_body, subst); + } + } + return post; + } + + private: + /*! \brief Original module we are processing. */ + IRModule mod_; + /*! \brief Global vars of functions to inline. */ + Array global_vars_; +}; + } // namespace GlobalSymbolCache::~GlobalSymbolCache() = default; @@ -106,10 +184,10 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach runtime::TypedPackedFunc pass_func = [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( IRModule mod, transform::PassContext ctx) { - IRModule output_mod = GetRef(mod.CopyOnWrite()); + VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); + IRModule output_mod = mod->ShallowCopy(); for (const auto& kv : mod->functions) { - const FunctionNode* function_node = AsOptimizableFunctionNode(kv.second); - if (function_node) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { Expr new_body = Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); Function new_function = @@ -117,6 +195,7 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach output_mod->Add(kv.first, new_function); } } + VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); return output_mod; }; @@ -132,31 +211,57 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { + VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod); IRModule output_mod = mod->ShallowCopy(); for (const auto& kv : mod->functions) { - if (const auto* function_node = kv.second.as()) { - Optional opt_compiler = function_node->GetAttr(attr::kCompiler); - if (opt_compiler.defined() && - (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { - auto new_function = WithFields( - GetRef(function_node), function_node->params, function_node->body, - function_node->ret_type, function_node->type_params, - /* erase attributes */ DictAttrs(Map())); - new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); - output_mod->Update(kv.first, new_function); - } + if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) { + auto new_function = + WithFields(GetRef(function_node), function_node->params, + function_node->body, function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); + new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); + output_mod->Update(kv.first, new_function); } } + VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod); return output_mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); } +transform::Pass InlineCompilerFunctions(Array global_vars) { + runtime::TypedPackedFunc pass_func = + [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { + VLOG(1) << "InlineCompilerFunctions with global_vars: " << PrettyPrint(global_vars); + if (global_vars.empty()) { + return mod; + } + VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); + IRModule output_mod = mod->ShallowCopy(); + for (const auto& kv : mod->functions) { + if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) { + output_mod->Remove(kv.first); + } else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body); + Function new_function = + WithFields(GetRef(function_node), /*opt_params=*/{}, new_body); + output_mod->Add(kv.first, new_function); + } + } + VLOG(1) << "InlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); + return output_mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsImpl", {}); +} + TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols") .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols); TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") .set_body_typed(MarkCompilerFunctionsAsExtern); +TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctions") + .set_body_typed(InlineCompilerFunctions); } // namespace transforms } // namespace relay diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 9d1dcd9f21a22..5cd89cbb9d2de 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -22,10 +22,10 @@ * \brief Helper passes for working with functions with the "Compiler" attribute. * * Those wishing to use the "RelayToTIR" custom pass machinery to do IRModule-at-a-time external - * codegen may find the following two helper passes useful: + * codegen may find the following helpers useful: * - * - \p OutlineCompilerFunctionsWithExistingGlobalSymbols will lift inline functions with a - * matching "Compiler" attribute to be global functions, using the "global_symbol" attribute + * - The \p OutlineCompilerFunctionsWithExistingGlobalSymbols pass will lift inline functions with + * a matching "Compiler" attribute to be global functions, using the "global_symbol" attribute * already assigned. Can be used before custom lowering. * * Note that ideally "Compiler" attributed functions would be made global functions as early as @@ -36,15 +36,22 @@ * * See also OutlineCompilerFunctionsMutator in src/relay/backend/contrib/ethosu/codegen.cc. * - * - (\p OutlineCompilerFunctions is a more general version of the above which can use a custom - * cache to both allocate "global_symbol" names and ensure two strucurally equal functions are - * assigned the same name, and thus lowered only once. This is used by Collage when preparing - * the optimally partitioned IRModule). + * - (The \p OutlineCompilerFunctions pass is a more general version of the above which can use + * a custom cache to both allocate "global_symbol" names and ensure two structurally equal + * functions are assigned the same name, and thus lowered only once. This is used by Collage + * when preparing the optimally partitioned IRModule). * - * - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler" - * attribute with the same function with just an "Extern" attribute, signalling the function - * has been dealt with. However calls to such functions will be left unchanged. Can be used - * after lowering to cleanup the IRModule. + * - The \p MarkCompilerFunctionsAsExtern pass will update the attributes of global functions + * with a matching "Compiler" attribute to have just the "Extern" attribute. That will signal + * the function has been dealt with. However calls to such functions will be left unchanged. + * Can be used after lowering to cleanup the IRModule. + * + * - The \p InlineCompilerFunctions pass can selectively inline global functions with a matching + * "Compiler" attribute who's name appears in the given set. Obviously it's more sensible to + * not create that function in the first place, however some external codegen have rules to + * accept or reject partitionings based on the overall partitioned function body. This pass + * can be used do the legwork, and will take care to not only inline the outer "Compiler" + * annotated funcition, but also any "Composite" annotated functions in its body. */ #ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_ @@ -126,6 +133,16 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co */ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); +/*! + * \brief A pass to inline all global "Compiler" functions which are bound to a global var + * in \p global_vars. Both the global function and any "Composite" functions it its body are + * inlined. + * + * This pass may be useful for external codegen which needs to undo partitioning based on + * properties of the entire partition. + */ +transform::Pass InlineCompilerFunctions(Array global_vars); + } // namespace transforms } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 00953a1907e13..f52e95b2adbfc 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -148,7 +148,7 @@ class TargetHookVisitor : public MixedModeVisitor { Pass RelayToTIRTargetHook(CompilationConfig config) { auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { - VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook before:" << std::endl << PrettyPrint(mod); TargetHookVisitor target_hook_visitor(mod, config); std::vector custom_passes = target_hook_visitor.Visit(); for (const auto& custom_pass : custom_passes) { @@ -161,11 +161,14 @@ Pass RelayToTIRTargetHook(CompilationConfig config) { mod = custom_pass.pass(mod); } else { // Invoke the pass. + // Note that there may be a non-external codegen target in scope. Each custom pass + // must be prepared to handle this, eg by creating a default target instance if the + // current target is either null or of a generic kind such as 'cuda' or 'llvm'. VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; mod = custom_pass.pass(mod); } } - VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook after:" << std::endl << PrettyPrint(mod); return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 97299c63752d7..dd89edea25f56 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -202,6 +202,8 @@ runtime::Module CreateMetadataModule( String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { + VLOG(1) << "From module of type '" << mod->type_key() << "' found const var '" + << variables[i] << "' for symbol '" << symbol << "'"; symbol_const_vars.push_back(variables[i].operator std::string()); } ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index b9eb115475956..d2476f2361db1 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -42,7 +42,7 @@ def make_consts(dtype, shapes): } -def inlined_mod(): +def original_mod(): return tvm.parser.parse( """ #[version = "0.0.5"] @@ -143,10 +143,35 @@ def @tvmgen_default_cutlass_main_0(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i ) +def expected_inlined_mod(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + %0 = nn.dense(%x0, meta[relay.Constant][0], units=2304); + %1 = add(%0, meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + def test_outline_compiler_functions_with_existing_global_symbols(): actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( "cutlass" - )(inlined_mod()) + )(original_mod()) tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) @@ -157,5 +182,12 @@ def test_mark_compiler_functions_as_extern(): tvm.ir.assert_structural_equal(actual_extern_mod, expected_extern_mod(), map_free_vars=True) +def test_inline_compiler_functions(): + mod = expected_outlined_mod() + gv = mod.get_global_var("tvmgen_default_cutlass_main_0") + actual_inlined_mod = tvm.relay.transform.InlineCompilerFunctions([gv])(mod) + tvm.ir.assert_structural_equal(actual_inlined_mod, expected_inlined_mod(), map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main()