From 700669b8dc5a25f5e863fab416f06de81af8203a Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 22 Nov 2021 14:23:43 -0800 Subject: [PATCH] Prepare for switching VM to LowerTEPass. This is a grab bag of fallout changes from switching the VM to use LoweTEPass which can be ealy split out of the main #9483 PR. - AnnotateSpans can be used from C++ (though, unfortunately, it didn't help me with debugging since spans are universally dropped in most passes). - Can get a human readable dump of the VM's PackedFunc names and indexes for debugging. - If TVM_LOG_DEBUG defined then include types and ids of GlobalVars. I had a lot of difficulty tracking down where duplicate GlobalVars for the same name_hint were getting created and propagated. - GetCallLoweredProps follows same API as GetDeviceCopy and GetOnDevice where will return 'null' properties if call/expr is not of call_lowered form. Mildly more convenient, though switching all the above to ICHECK and push 'if (op == the relevant op)' into all use sites would also be just fine. - Misc VLOG improvements made while tracking down issues in #9483. --- include/tvm/parser/parser.h | 10 +++++- include/tvm/runtime/vm/executable.h | 11 ++++-- include/tvm/target/compilation_config.h | 22 ++++++++---- include/tvm/target/se_scope.h | 6 ++++ python/tvm/runtime/vm.py | 7 ++++ src/ir/module.cc | 13 ++++--- src/parser/parser.cc | 27 ++++++++------ src/printer/relay_text_printer.cc | 12 ++++++- src/printer/text_printer.cc | 20 +++++++++-- src/relay/backend/aot_executor_codegen.cc | 33 +++++++++++++---- src/relay/backend/graph_executor_codegen.cc | 4 +-- src/relay/backend/interpreter.cc | 5 +-- src/relay/backend/te_compiler.cc | 35 +++++++++++++----- src/relay/backend/te_compiler_cache.cc | 6 ++-- src/relay/backend/utils.cc | 5 +++ src/relay/backend/vm/compiler.cc | 5 ++- src/relay/backend/vm/compiler.h | 9 ++++- src/relay/op/call/call.cc | 40 +++++++++++++-------- src/relay/op/call/call.h | 10 ++++-- src/relay/op/memory/device_copy.cc | 11 ++++++ src/relay/op/memory/device_copy.h | 10 ++++++ src/relay/transforms/device_domains.cc | 32 +++-------------- src/runtime/vm/executable.cc | 33 ++++++++++++++--- src/runtime/vm/vm.cc | 13 +++++++ src/target/compilation_config.cc | 12 ++++--- src/target/target.cc | 2 +- tests/cpp/target/compilation_config_test.cc | 30 +++++++++++++--- 27 files changed, 311 insertions(+), 112 deletions(-) diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 8c2722050905..0a73e1a2a532 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -20,10 +20,11 @@ #ifndef TVM_PARSER_PARSER_H_ #define TVM_PARSER_PARSER_H_ /*! - * \file parser.h + * \file include/tvm/parser/parser.h * \brief A parser for TVM IR. */ #include +#include #include #include @@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte const Optional& init_module = Optional(), const MetaTable& init_meta_table = MetaTable()); +/*! + * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources + * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for + * modules constructed programaticaly rather than textually. + */ +transform::Pass AnnotateSpans(); + } // namespace parser } // namespace tvm diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 311667904df6..5b18d05af80c 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -144,6 +144,13 @@ class Executable : public ModuleNode { */ std::string GetVirtualDevices() const; + /*! + * \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable. + * These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by + * a BYOC external codegen. + */ + std::string GetPrimitives() const; + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. @@ -201,9 +208,9 @@ class Executable : public ModuleNode { int host_device_index = -1; /*! \brief The global constant pool. */ std::vector constants; - /*! \brief A map from globals (as strings) to their index in the function map. */ + /*! \brief A map from globals (as strings) to their index in the Relay function map. */ std::unordered_map global_map; - /*! \brief A mapping from the packed function (as string) to the index that + /*! \brief A mapping from the packed function's global name (as string) to the index that * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. */ std::unordered_map primitive_map; diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h index facb74d6278e..45ff774f1742 100644 --- a/include/tvm/target/compilation_config.h +++ b/include/tvm/target/compilation_config.h @@ -32,7 +32,7 @@ namespace tvm { /*! * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to - * compile a Relay module. All centralizes any setup and validation logic needed to transition + * compile a Relay module. Centralizes any setup and validation logic needed to transition * from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly * (eg a a list of \p Targets) to the configuration. * @@ -49,9 +49,12 @@ namespace tvm { class CompilationConfigNode : public Object { public: /*! - * \brief The legacy targets map, mapping device type to \p Targets. Does not include any - * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, - * though we want to get rid of that limitation. + * \brief The legacy targets map, mapping device type to the corresponding \p Target to use + * when compiling primitive functions. Does not include an entry for the host target, however + * each \p Target in this map will have it's \p host field set to the \p host_target. + * + * Currently we require at most one \p Target per \p DLDeviceType, though we want to get rid of + * that limitation. * * CAUTION: Since keys are \p Integers they are compared by object equality not integer * value. @@ -63,13 +66,18 @@ class CompilationConfigNode : public Object { /*! * \brief The host target. Used for 'scalar' data and code (such as shapes and shape * functions) and residual Relay expressions and data (such as conditionals and ADTs). + * + * Note that it is possible for a \p Target used for primitive operations to be structurally + * equal to the host \p Target (up to the \p host field.) However the \p Target objects will + * be distinct, and can be used as keys within a \p Map without collision. */ Target host_target; /*! - * \brief Vector of all available targets for primitive operators. May contain a \p Target - * for the same device type as for the \p host_target, however the \p host_target should - * be preferred for all host computations and data. + * \brief Vector of all available \p Targets for compiling primitive operators. May contain + * a \p Target for the same device type as for the \p host_target, however the \p host_target + * should be used for all host computations and data. Each \p Target will have \p host_target + * as its host. */ Array primitive_targets; diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index 981a0b85ab13..b5928f93052c 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -297,6 +297,12 @@ class SEScope : public ObjectRef { return SEScope(device.device_type, device.device_id, std::move(target)); } + /*! \brief Returns the \p SEScope for \p target. */ + static SEScope ForTarget(Target target) { + return SEScope(static_cast(target->kind->device_type), /*virtual_device_id=*/0, + std::move(target)); + } + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, MemoryScope memory_scope) { diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 365e38c6e06c..1c11009a99a9 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -73,6 +73,7 @@ def __init__(self, mod): self._get_bytecode = self.mod["get_bytecode"] self._get_constants = self.mod["get_constants"] self._get_virtual_devices = self.mod["get_virtual_devices"] + self._get_primitives = self.mod["get_primitives"] self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] @@ -257,6 +258,12 @@ def virtual_devices(self): """Returns a human-readable description of all the (virtual) devices in the executable.""" return self._get_virtual_devices() + @property + def primitive(self): + """Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the + executable""" + return self._get_primitives() + @property def globals(self): """Get the globals used by the Relay VM executable. diff --git a/src/ir/module.cc b/src/ir/module.cc index c63e1df79f2e..6f2c9f9fe994 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -213,7 +213,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { ICHECK_EQ((*it).second, var); } else { ICHECK(global_var_map_.count(var->name_hint) == 0) - << "Duplicate global function name " << var->name_hint; + << "Duplicate global function name " << PrettyPrint(var); } global_var_map_.Set(var->name_hint, var); @@ -243,7 +243,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& if (!update) { // set global type var map ICHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << var->name_hint; + << "Duplicate global type definition name " << PrettyPrint(var); } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); @@ -266,7 +266,7 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint; + ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var); return (*it).second; } @@ -277,7 +277,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const { TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint; + ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var); return (*it).second; } @@ -306,6 +306,10 @@ String IRModuleNode::GetUniqueName(const String& name) { } } +/*! + * \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs + * ('one') side above the rhs ('two'). + */ struct Renamer : relay::ExprMutator, TypeMutator { Map defs; Map types; @@ -411,7 +415,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Mapimport_set_.count(path) == 0) { this->import_set_.insert(path); - DLOG(INFO) << "Importing: " << path; std::fstream src_file(path, std::fstream::in); std::string file_contents{std::istreambuf_iterator(src_file), std::istreambuf_iterator()}; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 092d5b61eeec..44aeeb3bdee1 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content, IRModule ParseModule(const std::string& file_name, const std::string& file_content, const Optional& init_module, const MetaTable& init_meta_table) { - VLOG(9) << "ParseModule"; + VLOG_CONTEXT << "ParseModule"; + VLOG(9) << "parsing and type-checking " << file_name; auto parser = InitParser(file_name, file_content, init_module, init_meta_table); auto mod = parser.ParseModule(); ICHECK(mod.defined()) << "The parser must return a non-null module."; @@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr") return ParseExpr(file_name, file_content); }); -TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() { - return CreateModulePass( - [](const IRModule& mod, const PassContext& ctx) { - String text = AsText(mod, /*show_meta_data=*/true); - VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; - return ParseModule("GeneratedSource", text); - }, - 0, "AnnotateSpans", {}); -}); +/*! + * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources + * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for + * modules constructed programaticaly rather than textually. + */ +Pass AnnotateSpans() { + auto pass_func = [](const IRModule& mod, const PassContext& ctx) { + String text = AsText(mod, /*show_meta_data=*/true); + VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; + return ParseModule("GeneratedSource", text); + }; + return CreateModulePass(pass_func, 0, "AnnotateSpans", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans); } // namespace parser } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 7454cfdf336e..ed5f5f62af94 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { return PrintFunc(Doc::Text("fn "), GetRef(op)); } -Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); } +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { + Doc doc; + doc << "@" << op->name_hint; +#if TVM_LOG_DEBUG + if (op->checked_type_.defined()) { + doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */"; + } + doc << " /* id=" << reinterpret_cast(op) << " */"; +#endif + return doc; +} Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 444cb0828c94..5acb9bd3f1dc 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -56,13 +56,29 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { if (kv.second.as()) { std::ostringstream os; os << "def @" << kv.first->name_hint; +#if TVM_LOG_DEBUG + os << " /* id=" << reinterpret_cast(kv.first.get()) << " */"; +#endif doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); } else if (kv.second.as()) { - doc << "@" << kv.first->name_hint << " = "; - doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + doc << "@" << kv.first->name_hint; +#if TVM_LOG_DEBUG + doc << " /* id=" << reinterpret_cast(kv.first.get()) << " */"; +#endif + doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); } doc << Doc::NewLine(); } +#if TVM_LOG_DEBUG + // attributes + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + doc << "attributes {" << Doc::NewLine(); + for (const auto& kv : mod->attrs->dict) { + doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine(); + } + doc << "}" << Doc::NewLine(); + } +#endif return doc; } diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 22a6542c8b9c..773f5edbcc2b 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -83,8 +83,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { Expr func; Array args; - if (call_node->op == CallLoweredOp()) { - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.lowered_func.defined()) { func = call_lowered_props.lowered_func; args = call_lowered_props.arguments; } else { // Relay functions that have not been lowered and lowered extern functions @@ -516,10 +516,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { } call_lowered_props = CallLoweredProps{GetRef(gvn), call_node->args, {}}; } else { - ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try " - "applying the fuse_ops transformation to the " - "expression."; call_lowered_props = GetCallLoweredProps(call_node); + ICHECK(call_lowered_props.lowered_func.defined()) + << "Operators should be transformed away; Try " + "applying the fuse_ops transformation to the " + "expression."; for (const auto& arg : call_lowered_props.arguments) { VisitExpr(arg); } @@ -717,6 +718,14 @@ class AOTExecutorCodegen : public MixedModeVisitor { : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {} LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { + VLOG_CONTEXT << "AOT"; + for (const auto& kv : targets_) { + VLOG(1) << "target: " << kv.second->ToDebugString(); + } + if (target_host_.defined()) { + VLOG(1) << "target host: " << target_host_->ToDebugString(); + } + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); Integer workspace_byte_alignment = @@ -793,10 +802,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); } - // Build the TIR IRModule for the AOT function + // Build the TIR IRModule for the main AOT function Map symbol_map; symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs); + VLOG(1) << "main module:" << std::endl << PrettyPrint(mod_run); // Apply storage rewrite pass to the runner function to do memory planning auto storage_rewrite = tir::transform::StorageRewrite(); @@ -827,12 +837,23 @@ class AOTExecutorCodegen : public MixedModeVisitor { ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; // This is the point where we separate the functions in the module by target + VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); + VLOG(1) << "per-target modules:"; + for (const auto& kv : ret.lowered_funcs) { + VLOG(1) << "target:" << std::endl + << kv.first->ToDebugString() << std::endl + << "maps to:" << std::endl + << PrettyPrint(kv.second); + } + ret.external_mods = external_modules.value(); if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + VLOG(1) << "merging main into existing module for host target"; ret.lowered_funcs[target_host_]->Update(mod_run); } else { + VLOG(1) << "adding main into new module for host target"; ret.lowered_funcs.Set(target_host_, mod_run); } diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 1e647b5ba1d3..ee748d0badfe 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -407,9 +407,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator inputs; std::string func_name; - if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.lowered_func.defined()) { // Extract function and arguments from the call_lowered op - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); func_name = call_lowered_props.lowered_func->name_hint; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index a87c60a05354..d9896c368416 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -684,8 +684,9 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const CallNode* call_node) final { - if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function. - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.lowered_func.defined()) { + // Special case: Call a lowered TIR function. // Evaluate only function args std::vector args; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index cfb4c7923a49..b339828b0cd4 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -100,7 +100,7 @@ class TECompilerImpl : public TECompilerNode { IRModule lowered_mod = lowered_func->cached_func->funcs; // Annotate functions with their target and put them in the return module - for (auto kv : lowered_mod->functions) { + for (const auto& kv : lowered_mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -114,6 +114,7 @@ class TECompilerImpl : public TECompilerNode { } } } + // Extract lowered dynamic shape functions from the shape cache for (const auto& it : shape_func_cache_) { auto source_func = it.first; @@ -129,6 +130,7 @@ class TECompilerImpl : public TECompilerNode { mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); } } + return mod; } @@ -202,10 +204,16 @@ class TECompilerImpl : public TECompilerNode { private: // implement lowered func CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { + VLOG(1) << "lowering:" << std::endl + << PrettyPrint(key->source_func) << std::endl + << "for target:" << std::endl + << key->target->ToDebugString(); std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); if (it != cache_.end()) { + VLOG(1) << "already lowered to:" << std::endl + << PrettyPrint(it->second->cached_func->prim_fn_var); it->second->use_count += 1; if (it->second->cached_func.defined()) return it->second; value = it->second; @@ -232,6 +240,11 @@ class TECompilerImpl : public TECompilerNode { // Collect these here as it's removed in LowerExternalFunctions() std::string codegen_name = key->source_func->GetAttr(attr::kCompiler).value(); device_contexts_.Set(global_var, codegen_name); + VLOG(1) << "preparing to use external codegen '" << codegen_name + << "' with name:" << std::endl + << PrettyPrint(value->cached_func->prim_fn_var) << std::endl + << "and definitions:" << std::endl + << PrettyPrint(value->cached_func->funcs); return value; } @@ -266,11 +279,19 @@ class TECompilerImpl : public TECompilerNode { cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); } value->cached_func = cfunc; + VLOG(1) << "lowered to name:" << std::endl + << PrettyPrint(value->cached_func->prim_fn_var) << std::endl + << "with definitions:" << std::endl + << PrettyPrint(value->cached_func->funcs); return value; } // implement lowered shape func CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + VLOG(1) << "lowering dynamic shape function:" << std::endl + << PrettyPrint(key->source_func) << std::endl + << "for target:" << std::endl + << key->target->ToDebugString(); std::lock_guard lock(mutex_); CCacheValue value; auto it = shape_func_cache_.find(key); @@ -295,6 +316,10 @@ class TECompilerImpl : public TECompilerNode { }); value->cached_func = cached_func; + VLOG(1) << "lowered to name:" << std::endl + << PrettyPrint(value->cached_func->prim_fn_var) << std::endl + << "with definitions:" << std::endl + << PrettyPrint(value->cached_func->funcs); return value; } @@ -480,11 +505,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } else { // Non-External Relay Function - VLOG(1) << "lowering to target " << target->ToDebugString() << " for primitive:\n" - << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); - VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; // Collect all the lowered functions produced for this primitive function. Map prim_fns; @@ -493,7 +515,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); all_prim_fn_vars.push_back(prim_fn.first); - VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT @@ -529,8 +550,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // on the host cpu irrespective of where the primitive runs. // TODO(mbs): Cleanup target handling. Target shape_target("llvm"); - VLOG(1) << "lowering to target " << shape_target->ToDebugString() - << " for dynamic shape function for primitive"; CCacheKey shape_key(func, shape_target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -925,7 +944,7 @@ void UpdateFunctionMetadata(BaseFunc func, std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes), std::move(tir_primfuncs), std::move(relay_primfuncs)); - VLOG(1) << "FunctionInfo: " << prim_fn_var.value()->name_hint << " = " << PrettyPrint(fi); + VLOG(1) << "FunctionInfo: " << PrettyPrint(prim_fn_var.value()) << " = " << PrettyPrint(fi); // The primitive function name here corresponds to the string we will use to generate // this Relay function at the low level. diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 266bd719545a..91265c46dcb4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -215,7 +215,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Unexpected free variable " << op->name_hint(); + LOG(FATAL) << "Unexpected free variable " << PrettyPrint(GetRef(op)); return {}; } @@ -384,7 +384,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> CachedFunc Create(const Function& prim_func, const Target& target, std::function renamer) { - Array inputs; TShapeDataDependent shape_func_param_states; for (auto param : prim_func->params) { @@ -429,6 +428,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } // Set all the inputs correctly. + Array inputs; for (auto param : prim_func->params) { int state = param_states_[param]; shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); @@ -496,7 +496,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> return VisitExpr(it->second); } if (param_states_.find(var) == param_states_.end()) { - LOG(FATAL) << "Unexpected free variable " << var->name_hint(); + LOG(FATAL) << "Unexpected free variable " << PrettyPrint(var); return {}; } else { ICHECK(data_dependents_per_input_.size()); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 9a1c428482e2..54275430434b 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -24,6 +24,7 @@ #include "utils.h" +#include #include #include "te_compiler.h" @@ -177,6 +178,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) Array GetPassPrefix(bool is_homegeneous, bool is_vm) { Array pass_seqs; + // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton + // by most passes there's little utility in including this now. Plus we'd need to only do + // this if there's no existing spans to work from. + // pass_seqs.push_back(parser::AnnotateSpans()); Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9bdd63a4b126..f2b094700a1b 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -996,7 +996,10 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) VLOG(1) << std::endl << "-------------------------------------------------" << std::endl - << exec_->GetVirtualDevices() << exec_->GetConstants() << exec_->GetBytecode() + << exec_->GetVirtualDevices() // + << exec_->GetConstants() // + << exec_->GetPrimitives() // + << exec_->GetBytecode() // << "-------------------------------------------------"; backend::UpdateAutoSchedulerOpWeights(context_.compiler); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 2edec70d5c3b..efe50c40f3d3 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -107,7 +107,14 @@ class VMCompiler : public runtime::ModuleNode { void SetParam(const std::string& name, runtime::NDArray data_in); /*! - * \brief Lower the functions in a Module + * \brief Lower the functions in a Module. + * + * ---------------------------------------------------------------------------------- + * | This is the main entry point for the VM compilation flow. | + * | - Preceded by \p SetParam for the global params. | + * | - Followed by \p Codegen() to finalize the executable. | + * | - Then the result runtime::Module can be constructed from the internal exec_. | + * ---------------------------------------------------------------------------------- * * \param mod Relay Module * \param targets For heterogeneous compilation, it is a dictionary indicating device type diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc index 87f18e860950..9e7f68de0dc8 100644 --- a/src/relay/op/call/call.cc +++ b/src/relay/op/call/call.cc @@ -67,7 +67,8 @@ Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_ar // Right now, call_lowered only supports func being a global var pointing to the lowered // function. ICHECK(func.as()) - << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey(); + << "Function to call should be GlobalVarNode, but got:" << std::endl + << PrettyPrint(func); ICHECK(attrs.as()) << "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey(); return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs), @@ -95,20 +96,29 @@ RELAY_REGISTER_OP("call_lowered") .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { - ICHECK(call_node->op == CallLoweredOp()) - << "GetCallLoweredProps expects the op to be call_lowered. "; - ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. "; - const auto* function = call_node->args[0].as(); - ICHECK(function) << "Expected first arg to call_lowered to be a GlobalVar. "; - - const auto* tuple_args = call_node->args[1].as(); - ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple. "; - - ICHECK(call_node->attrs.defined()) << "Attributes for call_lowered should be defined!"; - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " - << call_node->attrs->GetTypeKey(); - return CallLoweredProps{GetRef(function), tuple_args->fields, *attrs}; + if (call_node->op == CallLoweredOp()) { + ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments."; + const auto* function_node = call_node->args[0].as(); + ICHECK(function_node) << "Expected first arg to call_lowered to be a GlobalVar. "; + + const auto* tuple_args = call_node->args[1].as(); + ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments."; + + ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes."; + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " + << call_node->attrs->GetTypeKey(); + return CallLoweredProps{GetRef(function_node), tuple_args->fields, *attrs}; + } + return {}; +} + +bool IsReshapeOnly(const CallLoweredProps& props) { + if (props.attrs.metadata.count("relay_attrs")) { + auto dict_attrs = Downcast(props.attrs.metadata["relay_attrs"]); + return dict_attrs.HasNonzeroAttr(attr::kReshapeOnly); + } + return false; } } // namespace relay diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h index 381be6724e0d..1ff7ee482503 100644 --- a/src/relay/op/call/call.h +++ b/src/relay/op/call/call.h @@ -62,12 +62,16 @@ struct CallLoweredProps { }; /*! - * \brief Helper to extract the lowered function and its arguments from Call("call_lowered", ...). - * Will fail if called on a Call whose op is not "call_lowered" \param call_node CallNode that we - * want to get the function and its arguments from. + * \brief Helper to extract the lowered function and its arguments from a Call("call_lowered", ...). + * Returns the null/empty \p CallLoweredProps if \p call_node is not in that form. */ CallLoweredProps GetCallLoweredProps(const CallNode* call_node); +/*! + * \brief Returns true if lowered call described by \p props is to a reshape primitive. + */ +bool IsReshapeOnly(const CallLoweredProps& props); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index 48d12368fa28..6851a2e8bf58 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -24,6 +24,7 @@ #include "./device_copy.h" +#include #include #include #include @@ -116,5 +117,15 @@ DeviceCopyProps GetDeviceCopyProps(const Expr& expr) { return {}; } +DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props) { + if (props.attrs.metadata.count("src_se_scope") == 1 && + props.attrs.metadata.count("dst_se_scope") == 1) { + ICHECK_EQ(props.arguments.size(), 1) << "device_copy is of arity 1"; + return {props.arguments[0], Downcast(props.attrs.metadata["src_se_scope"]), + Downcast(props.attrs.metadata["dst_se_scope"])}; + } + return {}; +} + } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index 3b40f410e53b..b8606a3f0c9e 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -30,6 +30,8 @@ #include +#include "../call/call.h" + namespace tvm { namespace relay { @@ -77,6 +79,14 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); */ DeviceCopyProps GetDeviceCopyProps(const Expr& expr); +/*! + * \brief As for GetDeviceCopyProps, but for a lowered call rather than the original + * "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props); + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 44c0ecf41de8..af75faea573d 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -36,30 +36,6 @@ namespace tvm { namespace relay { namespace transform { -namespace { - -/*! - * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather - * than the original "device_copy" operator. - * - * See te_compiler.cc for where this rewriting occurs. - */ -DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { - if (call_node->op == CallLoweredOp()) { - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - if (call_lowered_props.attrs.metadata.count("source_device") == 1 && - call_lowered_props.attrs.metadata.count("dst_device") == 1) { - ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; - return {call_lowered_props.arguments[0], - Downcast(call_lowered_props.attrs.metadata["src_se_scope"]), - Downcast(call_lowered_props.attrs.metadata["dst_se_scope"])}; - } - } - return {}; -} - -} // namespace - DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) { host_domain_ = MakeFirstOrderDomain(config_->host_se_scope); } @@ -221,9 +197,10 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { OnDeviceProps on_device_props = GetOnDeviceProps(call.get()); DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); - if (!device_copy_props.body.defined()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); + if (!device_copy_props.body.defined() && call_lowered_props.lowered_func.defined()) { // Special case for the TIR-ified version of "device_copy". - device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + device_copy_props = GetLoweredDeviceCopyProps(call_lowered_props); } if (on_device_props.body.defined()) { @@ -324,8 +301,7 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); - } else if (call->op == CallLoweredOp()) { - CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); + } else if (call_lowered_props.lowered_func.defined()) { return DomainFor(call_lowered_props.lowered_func); } else { // We still need to handle the case where the function / op is not lowered diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 4a044584dccd..b613a03bfc5c 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -65,6 +65,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetConstants(); }); } else if (name == "get_virtual_devices") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); + } else if (name == "get_primitives") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); }); } else if (name == "get_stats") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { @@ -123,10 +125,14 @@ std::string Executable::GetBytecode() const { const auto& func = functions[i]; // Print the header of the function format. oss << "VM Function[" << i << "]: " << func.name << "("; + bool first = true; for (const auto& param : func.params) { - oss << param << ", "; + if (!first) { + oss << ", "; + } + oss << param; + first = false; } - oss.seekp(-2, std::ios_base::end); oss << ")" << std::endl; oss << "# reg file size = " << func.register_file_size << std::endl; oss << "# instruction count = " << func.instructions.size() << std::endl; @@ -137,10 +143,12 @@ std::string Executable::GetBytecode() const { for (size_t idx = 0; idx < func.instructions.size(); ++idx) { const auto& instr = func.instructions[idx]; const auto& serialized_instr = SerializeInstruction(instr); - oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " "; + std::ostringstream line; + line << std::setw(2) << idx << ": " << serialized_instr.opcode << " "; for (auto it : serialized_instr.fields) { - oss << it << " "; + line << it << " "; } + oss << std::setw(40) << std::setfill(' ') << std::left << line.str(); oss << " # " << instr; if (oss.str().back() != '\n') oss << std::endl; } @@ -186,6 +194,23 @@ std::string Executable::GetVirtualDevices() const { return oss.str(); } +std::string Executable::GetPrimitives() const { + std::ostringstream os; + std::vector> entries; + entries.reserve(primitive_map.size()); + for (const auto& kv : primitive_map) { + entries.emplace_back(kv.second, kv.first); + } + std::sort(entries.begin(), entries.end(), + [](const std::pair& left, const std::pair& right) { + return left.first < right.first; + }); + for (const auto& entry : entries) { + os << "VM PackedFunc[" << entry.first << "]: " << entry.second << std::endl; + } + return os.str(); +} + std::string Executable::Stats() const { std::ostringstream oss; oss << "Relay VM executable statistics:" << std::endl; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 05adf1d69e8d..5d18b76a9d80 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -67,7 +67,9 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) { if (src->IsInstance()) { auto nd_array = Downcast(src); + // TODO(mbs): Should respect device id also. if (nd_array->device.device_type != dev.device_type) { + VLOG(2) << "copying from " << nd_array->device.device_type << " to " << dev.device_type; return nd_array.CopyTo(dev); } return src; @@ -712,6 +714,17 @@ void VirtualMachine::RunLoop() { int64_t ndim = shape_tensor->shape[0]; std::vector shape(dims, dims + ndim); // Reshape the input tensor +#if TVM_LOG_DEBUG + std::ostringstream os; + os << "ReshapeTensor: "; + os << "shape=["; + for (auto i : shape) { + os << i << ","; + } + os << "]"; + os << ", dtype=" << DLDataType2String(tensor_arr->dtype); + VLOG(2) << os.str(); +#endif auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype); WriteRegister(instr.dst, out_tensor); OpStopHook(); diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index 9797b6751af7..e9c4daf7bf05 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -111,8 +111,7 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex /*virtual_device_id=*/0, host_target)); // - // Now that we've settled on a host, make sure all the primitive Targets agree on it for - // their 'host' field. This mutates the primitives. + // Now that we've settled on a host, we can set it as the host on all primitive targets. // Array new_primitve_targets; new_primitve_targets.reserve(primitive_targets.size()); @@ -162,6 +161,7 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex if (name == "cpu") { if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { // LLVM is available. + // TODO(mbs): More robust extension mechanism? return Target("llvm"); } else { // LLVM is not available. @@ -196,7 +196,7 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, VLOG(0) << "Available host target " << optional_host_target_arg->ToDebugString(); } - // Capture the arguments in our representation. + // Capture the arguments in our preferred representation. for (const auto& pair : legacy_target_map_arg) { node->primitive_targets.push_back(pair.second); } @@ -207,7 +207,9 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, // all primitive targets will have host target_host. node->EstablishDefaultSEScopes(pass_ctx); - // LEGACY: Reconstruct the target map with all the primitive targets. + // LEGACY: Reconstruct the target map from all the primitive targets. + // Note that we require pointer equality between targets in legacy_target_map and + // primitive_targets. for (const auto& primitive_target : node->primitive_targets) { node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); } @@ -219,7 +221,7 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, // Legacy: Some passes only support homogenous compilation and expect the target to be // given by the global target context. Make this easy to detect. node->optional_homogeneous_target = - node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); + node->legacy_target_map.size() == 1 ? (*node->legacy_target_map.begin()).second : Target(); for (const auto& target : node->primitive_targets) { DLOG(INFO) << "Target " << target->ToDebugString() << " of device type " diff --git a/src/target/target.cc b/src/target/target.cc index 6f5e8ee67b30..792884061db6 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -558,7 +558,7 @@ String TargetNode::ToDebugString() const { if (!first) { os << ", "; } - os << '"' << pair.first << "': " << pair.second; + os << "'" << pair.first << "': " << pair.second; first = false; } os << "}"; diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index ae5f5d0c3dc4..31b936807edc 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -81,9 +81,11 @@ TEST(CompilationConfig, Constructor_Homegenoous_InnerHost) { TEST(CompilationConfig, Constructor_Homogenous_CPUHost) { transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestCpuTarget(); Target cpu_target = TestCpuTarget(); TargetMap legacy_target_map; - legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), + Target::WithHost(cpu_target, host_target)); CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); EXPECT_TRUE(StructuralEqual()(config->host_target, cpu_target)); @@ -99,8 +101,10 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { Target cuda_target = TestCudaTarget(); Target cpu_target = TestCpuTarget(); TargetMap legacy_target_map; - legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); - legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), + Target::WithHost(cpu_target, host_target)); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), + Target::WithHost(cuda_target, host_target)); CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, @@ -108,6 +112,13 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { SEScope expected_host_se_scope(kDLCPU, 0, host_target); ASSERT_EQ(config->legacy_target_map.size(), 2); + for (const auto& pair : config->legacy_target_map) { + if (pair.first->value == kDLCPU) { + EXPECT_TRUE(StructuralEqual()(pair.second, Target::WithHost(cpu_target, host_target))); + } else if (pair.first->value == kDLCUDA) { + EXPECT_TRUE(StructuralEqual()(pair.second, Target::WithHost(cuda_target, host_target))); + } + } EXPECT_TRUE(config->host_target.defined()); EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); EXPECT_TRUE( @@ -123,8 +134,10 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { Target cuda_target = TestCudaTarget(); Target cpu_target = TestCpuTarget(); TargetMap legacy_target_map; - legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); - legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), + Target::WithHost(cpu_target, host_target)); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), + Target::WithHost(cuda_target, host_target)); CompilationConfig config(pass_ctx, legacy_target_map, host_target); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, @@ -132,6 +145,13 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { SEScope expected_host_se_scope(kDLCPU, 0, host_target); ASSERT_EQ(config->legacy_target_map.size(), 2); + for (const auto& pair : config->legacy_target_map) { + if (pair.first->value == kDLCPU) { + EXPECT_TRUE(StructuralEqual()(pair.second, Target::WithHost(cpu_target, host_target))); + } else if (pair.first->value == kDLCUDA) { + EXPECT_TRUE(StructuralEqual()(pair.second, Target::WithHost(cuda_target, host_target))); + } + } EXPECT_TRUE(config->host_target.defined()); EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); ASSERT_EQ(config->primitive_targets.size(), 2);