diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index 2bad0d8da8ec1..276c64a0d7538 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file tvm/ir/global_var_supply.h + * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar. + */ #ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_ #define TVM_IR_GLOBAL_VAR_SUPPLY_H_ @@ -29,20 +33,51 @@ namespace tvm { +/*! + * \brief GlobalVarSupply can be used to generate unique GlobalVars. + */ class GlobalVarSupplyNode : public Object { public: + /*! + * \brief Empty constructor. Will use an empty NameSupply. + */ GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {} - explicit GlobalVarSupplyNode(NameSupply name_supply); - + /*! + * \brief Constructor. + * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars. + * \param name_to_var_map An optional map. + */ + explicit GlobalVarSupplyNode(NameSupply name_supply, + std::unordered_map name_to_var_map = {}); + + /*! + * \brief Generates a unique GlobalVar from this supply. + * \param name The name from which the name of the GlobalVar is derived. + * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended + * to the name. \return A unique GlobalVar. + */ GlobalVar FreshGlobal(String name, bool add_prefix = true); + /*! + * \brief Looks up for a GlobalVar with the given name in this supply. + * If no entry is found, creates one, places it in the cache and returns it. + * \param name The name of the GlobalVar to search for. + * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to + * the name before performing the search. \return A cached GlobalVar. + */ GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true); + /*! + * \brief Reserves an existing GlobalVar with this supply. + * \param var The GlobalVar to be registered. + * \param allow_conflict Allow conflict with other GlobalVars that have the same name. + */ void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false); - void VisitAttrs(AttrVisitor* v) { v->Visit("name_supply", &name_supply_); } + void VisitAttrs(AttrVisitor* v) {} + /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ NameSupply name_supply_; static constexpr const char* _type_key = "GlobalVarSupply"; @@ -52,28 +87,37 @@ class GlobalVarSupplyNode : public Object { private: std::unordered_map name_to_var_map_; - - friend class GlobalVarSupply; }; +/*! + * \brief Managed reference class to GlobalVarSupplyNode. + * \sa GlobalVarSupplyNode + */ class GlobalVarSupply : public ObjectRef { public: - TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(), + /*! + * \brief Constructor. + * \param name_supply The NameSupply to be used when generating new GlobalVars. + * \param name_to_var_map An optional map. + */ + TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map = {}); + /*! + * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are + * guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array + * of IRModules. + */ TVM_DLL explicit GlobalVarSupply(const Array& modules); + /*! + * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are + * guaranteed not to conflict with GlobalVars that belong to the modules. \param module The + * IRModule. + */ TVM_DLL explicit GlobalVarSupply(const IRModule module); - explicit GlobalVarSupply(ObjectPtr n) : ObjectRef(n) {} - /*! \return mutable pointers to the node. */ - GlobalVarSupplyNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - - TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode); }; } // namespace tvm diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 28cf43ae86845..7313b4f783492 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -480,7 +480,7 @@ namespace attr { * * \sa tvm::runtime::String */ -constexpr const char* kModuleName = "name"; +constexpr const char* kModuleName = "mod_name"; /*! * \brief Executor targeted by the module diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 26bd1b3c1de81..a85a6fe70a66a 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -17,31 +17,66 @@ * under the License. */ +/*! + * \file tvm/ir/name_supply.h + * \brief NameSupply that can be used to generate unique variable names. + */ #ifndef TVM_IR_NAME_SUPPLY_H_ #define TVM_IR_NAME_SUPPLY_H_ #include #include +#include #include "tvm/ir/expr.h" namespace tvm { +/*! + * \brief NameSupply can be used to generate unique names. + */ class NameSupplyNode : public Object { public: - NameSupplyNode() : NameSupplyNode("") {} - - explicit NameSupplyNode(const String& prefix); - + /*! + * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. + */ + NameSupplyNode() = default; + + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this NameSupply. + * \param name_map The map used to guarantee uniqueness. + */ + NameSupplyNode(const String& prefix, std::unordered_map name_map) + : prefix_(prefix), name_map(std::move(name_map)) {} + + /*! + * \brief Generates a unique name from this NameSupply. + * \param name The name from which the generated name is derived. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name. \return A unique name. + */ String FreshName(const String& name, bool add_prefix = true); + /*! + * \brief Reserves an existing name with this NameSupply. + * \param name The name to be reserved. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before reserving it. \return The name that was reserved with the NameSupply. It can be + * different if a prefix is added. + */ String ReserveName(const String& name, bool add_prefix = true); + /*! + * \brief Checks if this NameSupply already generated a name. + * \param name The name to check. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before checking for it. \return True if the name has already been generated. False + * otherwise. + */ bool ContainsName(const String& name, bool add_prefix = true); - void Clear(); - - void VisitAttrs(AttrVisitor* v) { v->Visit("prefix", &prefix_); } + void VisitAttrs(AttrVisitor* v) {} // Prefix for all GlobalVar names. It can be empty. std::string prefix_; @@ -52,32 +87,35 @@ class NameSupplyNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); private: - String prefix_module_name(const String& name); - + /*! \brief Helper function to add the NameSupply prefix to the name. */ + String add_prefix_to_name(const String& name); + + /*! + * \brief Function that will generate a unique name. + * \param name The name to be used as a base. + * \return A unique name. + */ std::string GetUniqueName(std::string name); - // Key is function_name. Value is a counter. + /*! \brief A map that is used to generate unique names. */ std::unordered_map name_map; - - friend class NameSupply; }; +/*! + * \brief Managed reference class to NameSupplyNode. + * \sa NameSupplyNode + */ class NameSupply : public ObjectRef { public: - TVM_DLL explicit NameSupply(); - + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this NameSupply. + * \param name_map An optional map. + */ TVM_DLL explicit NameSupply(const String& prefix, std::unordered_map name_map = {}); - explicit NameSupply(ObjectPtr n) : ObjectRef(n) {} - /*! \return mutable pointers to the node. */ - NameSupplyNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - - TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode); }; } // namespace tvm diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index cd4290daefa63..095ac43c03b85 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -33,23 +33,52 @@ def __init__(self, prefix=""): self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) def fresh_name(self, name, add_prefix=True): + """Generates a unique name from this NameSupply. + + Parameters + ---------- + name: String + The name from which the generated name is derived. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name. + """ return _ffi_api.NameSupply_FreshName(self, name, add_prefix) def reserve_name(self, name, add_prefix=True): + """Reserves an existing name with this NameSupply. + + Parameters + ---------- + name: String + The name to be reserved. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name + before reserving it. + """ return _ffi_api.NameSupply_ReserveName(self, name, add_prefix) def contains_name(self, name, add_prefix=True): - return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) + """Checks if this NameSupply already generated a name. - def clear(self): - return _ffi_api.NameSupply_Clear(self) + Parameters + ---------- + name: String + The name to check. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name + before checking for it. + """ + return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) @tvm._ffi.register_object("GlobalVarSupply") class GlobalVarSupply(Object): """GlobalVarSupply that holds a mapping between names and GlobalVars. - GlobalVarSupply can be used to generate new GlobalVars with an unique name. + GlobalVarSupply can be used to generate new GlobalVars with a unique name. It also can be used to retrieve previously generated GlobalVars based on a name. Parameters @@ -70,10 +99,43 @@ def __init__(self, value=None): self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value) def fresh_global(self, name, add_prefix=True): + """Generates a unique GlobalVar from this supply. + + Parameters + ---------- + name: String + The name from which the name of the GlobalVar is derived. + + add_prefix: bool + If set to true, then the prefix of the contained NameSupply will be prepended + to the name. + """ return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix) def unique_global_for(self, name, add_prefix=True): + """Looks up for a GlobalVar with the given name in this supply. If no entry is found + , creates one, places it in the cache and returns it. + + Parameters + ---------- + name: String + The name of the GlobalVar to search for. + + add_prefix: bool + If set to true, the prefix of the contained NameSupply will be prepended to the + name before performing the search. + """ return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix) def reserve_global(self, global_var, allow_conflict=False): + """Reserves an existing GlobalVar with this supply. + + Parameters + ---------- + global_var: GlobalVar + The GlobalVar to be registered. + + allow_conflict: bool + Allow conflict with other GlobalVars that have the same name + """ return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index eab5c9ec2e7c2..c930bf0c4e73b 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1372,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i auto pass_ctx = tvm::transform::PassContext::Current(); auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, - std::unordered_map(), GlobalVarSupply()); + std::unordered_map(), + GlobalVarSupply(NameSupply(""))); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index bbd83c9500e76..53026c7fc3b30 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -147,7 +147,7 @@ class CodeGenHybrid : public ExprFunctor, /*! \brief Print the current indent spaces. */ inline void PrintIndent(); /*! \brief NameSupply for allocated ids. */ - NameSupply ids_allocated = NameSupply(); + NameSupply ids_allocated = NameSupply(""); /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d95e7fca24422..25fdeb6da1c2b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -303,7 +303,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") c_binds.insert({kv.first, kv.second}); } } - IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply()); + IRModule mod = + ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply(""))); return mod; }); @@ -366,7 +367,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert({kv.first, kv.second}); } } - return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode); + return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")), + simple_mode); }); /** diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index f1a736e8d0fab..383d4445adcf8 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file global_var_supply.cc + * \brief GlobalVarSupply that can be used to generate unique GlobalVars. + */ #include "tvm/ir/global_var_supply.h" #include @@ -28,8 +32,7 @@ namespace tvm { GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map) { - auto n = make_object(name_supply); - n->name_to_var_map_ = std::move(name_to_var_map); + auto n = make_object(name_supply, name_to_var_map); data_ = std::move(n); } @@ -37,7 +40,7 @@ std::string GetModuleName(const IRModule& module) { return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); } -GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply() { +GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply(NameSupply("")) { if (!modules.empty()) { IRModule first_mod = modules.front(); this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); @@ -61,8 +64,9 @@ void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conf name_to_var_map_[var->name_hint] = var; } -GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply) - : name_supply_(std::move(name_supply)) {} +GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, + std::unordered_map name_to_var_map) + : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) { String final_name = name_supply_->ReserveName(name, add_prefix); @@ -72,7 +76,7 @@ GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_pref return it->second; } else { GlobalVar var = GlobalVar(final_name); - name_to_var_map_[final_name] = var; + name_to_var_map_.emplace(final_name, var); return var; } } @@ -82,7 +86,7 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) << "GlobalVar already exists for name " << final_name; GlobalVar var = GlobalVar(final_name); - name_to_var_map_[final_name] = var; + name_to_var_map_.emplace(final_name, var); return var; } diff --git a/src/ir/module.cc b/src/ir/module.cc index 25be477898c0c..8d6de5a536a70 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,6 +21,7 @@ * \file module.cc * \brief The global module in Relay. */ +#include #include #include #include @@ -40,8 +41,6 @@ #include #include -#include "tvm/ir/global_var_supply.h" - namespace tvm { IRModule::IRModule(tvm::Map functions, diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 9fd6df49d55d8..93f568253cba7 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file name_supply.cc + * \brief NameSupply that can be used to generate unique variable names. + */ #include "tvm/ir/name_supply.h" #include @@ -25,20 +29,15 @@ namespace tvm { -NameSupply::NameSupply() : NameSupply("") {} - NameSupply::NameSupply(const String& prefix, std::unordered_map name_map) { - auto n = make_object(prefix); - n->name_map = std::move(name_map); + auto n = make_object(prefix, std::move(name_map)); data_ = std::move(n); } -NameSupplyNode::NameSupplyNode(const String& prefix) : prefix_(prefix) {} - String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { String final_name = name; if (add_prefix) { - final_name = prefix_module_name(name); + final_name = add_prefix_to_name(name); } name_map[final_name] = 0; return final_name; @@ -47,7 +46,7 @@ String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { String NameSupplyNode::FreshName(const String& name, bool add_prefix) { String unique_name = name; if (add_prefix) { - unique_name = prefix_module_name(name); + unique_name = add_prefix_to_name(name); } unique_name = GetUniqueName(unique_name); return unique_name; @@ -56,43 +55,39 @@ String NameSupplyNode::FreshName(const String& name, bool add_prefix) { bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { String unique_name = name; if (add_prefix) { - unique_name = prefix_module_name(name); + unique_name = add_prefix_to_name(name); } return name_map.count(unique_name); } -void NameSupplyNode::Clear() { name_map.clear(); } - -String NameSupplyNode::prefix_module_name(const String& name) { +String NameSupplyNode::add_prefix_to_name(const String& name) { if (prefix_.empty()) { return name; } - std::stringstream ss; + std::ostringstream ss; ICHECK(name.defined()); ss << prefix_ << "_" << name; return ss.str(); } -std::string NameSupplyNode::GetUniqueName(std::string prefix) { - for (size_t i = 0; i < prefix.size(); ++i) { - if (prefix[i] == '.') prefix[i] = '_'; +std::string NameSupplyNode::GetUniqueName(std::string name) { + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == '.') name[i] = '_'; } - auto it = name_map.find(prefix); + auto it = name_map.find(name); if (it != name_map.end()) { - while (true) { + auto new_name = name; + while (!name_map.insert({new_name, 0}).second) { std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_map.count(name) == 0) { - prefix = name; - break; - } + os << name << "_" << (++it->second); + new_name = os.str(); } + return new_name; } - name_map[prefix] = 0; - return prefix; + name_map[name] = 0; + return name; } TVM_REGISTER_NODE_TYPE(NameSupplyNode); @@ -110,6 +105,4 @@ TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName") TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName") .set_body_method(&NameSupplyNode::ContainsName); -TVM_REGISTER_GLOBAL("ir.NameSupply_Clear").set_body_method(&NameSupplyNode::Clear); - } // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 574bda77f92b7..ab725d82e6760 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -630,7 +630,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief NameSupply */ - NameSupply name_supply_ = NameSupply(); + NameSupply name_supply_ = NameSupply(""); }; class GraphExecutorCodegenModule : public runtime::ModuleNode { diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 9e53db77980bd..c577e8e356d64 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -74,7 +74,7 @@ Array ExtractTask( }); // Tasks are extracted via post order visit, return the reversed list. std::reverse(tasks.begin(), tasks.end()); - NameSupply name_supply = NameSupply(); + NameSupply name_supply = NameSupply(""); for (ExtractedTask task : tasks) { task->task_name = name_supply->FreshName(task->task_name); } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b55cce643e92c..5c79ed2070cc6 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -136,37 +136,32 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode); class TECompilerImpl : public TECompilerNode { public: explicit TECompilerImpl(Optional opt_mod, Optional opt_mod_name) { - String mod_name; - if (!opt_mod_name) { - mod_name = ""; - } else { - mod_name = opt_mod_name.value(); - } - NameSupply name_supply = NameSupply(mod_name); - global_var_supply = GlobalVarSupply(name_supply); + String mod_name = opt_mod_name.value_or(""); + NameSupply name_supply = NameSupply(mod_name /* prefix */); + global_var_supply_ = GlobalVarSupply(name_supply); // Make sure we don't collide with any existing globals in the module. if (opt_mod) { for (const auto& kv : opt_mod.value()->functions) { - global_var_supply->name_supply_->ReserveName(kv.first->name_hint, false); + global_var_supply_->name_supply_->ReserveName(kv.first->name_hint, false); } } } // Lower the function. CachedFunc Lower(const CCacheKey& key) { - return LowerInternal(key, global_var_supply)->cached_func; + return LowerInternal(key, global_var_supply_)->cached_func; } - // TODO(gigiblender): Only to be called by the GlobalTECompiler. - // Remove this when the GlobalTECompiler is removed. + // TODO(gigiblender): Only to be called by the global TE compiler. + // Remove this when the global TE compiler is removed. CachedFunc Lower(const CCacheKey& key, const String mod_name) { - global_var_supply->name_supply_->prefix_ = mod_name; - return LowerInternal(key, global_var_supply)->cached_func; + global_var_supply_->name_supply_->prefix_ = mod_name; + return LowerInternal(key, global_var_supply_)->cached_func; } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { - CCacheValue value = LowerInternal(key, GlobalVarSupply()); + CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply(""))); if (value->packed_func != nullptr) { return value->packed_func; } @@ -499,7 +494,7 @@ class TECompilerImpl : public TECompilerNode { using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); - value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply); + value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply_); ICHECK( value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as()); @@ -527,7 +522,7 @@ class TECompilerImpl : public TECompilerNode { /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal GlobalVarSupply to get unique GlobalVars */ - GlobalVarSupply global_var_supply; + GlobalVarSupply global_var_supply_; /*! \brief internal compiler cache */ std::unordered_map cache_; /*! \brief internal compiler cache for shape funcs */ diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 07ebc68974634..25111cec8eda1 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply()); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc index a6682f2cf4898..8a70f224c611e 100644 --- a/src/relay/transforms/meta_schedule_layout_rewrite.cc +++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc @@ -127,7 +127,7 @@ Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) { if (const auto* func = call->op.as()) { LayoutIndexQueue* self = LayoutIndexQueue::Global(); self->queue_.clear(); - tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply()); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); if (!self->queue_.empty()) { std::deque queue = std::move(self->queue_); self->queue_.clear(); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 66c56ad4222f0..75833fd93629d 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -28,7 +28,7 @@ namespace tvm { namespace codegen { void CodeGenSourceBase::ClearFuncState() { - name_supply_->Clear(); + name_supply_ = NameSupply(""); ssa_assign_map_.clear(); var_idmap_.clear(); scope_mark_.clear(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 4a764eeb23317..2fd0abcd68a63 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -123,7 +123,7 @@ class CodeGenSourceBase { /*! \brief name of each variable */ std::unordered_map var_idmap_; /*! \brief NameSupply for allocation */ - NameSupply name_supply_ = NameSupply(); + NameSupply name_supply_ = NameSupply(""); private: /*! \brief assignment map of ssa */ diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index ccb7a56cc8daa..6521246ff41f5 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -63,7 +63,7 @@ struct CreateFuncInfo { /*! \brief The buffers should be allocated at function root. */ Array root_alloc; /*! \brief The NameSupply to make block name unique. */ - NameSupply name_supply = NameSupply(); + NameSupply name_supply = NameSupply(""); String FreshName(String base_name) { return name_supply->FreshName(base_name); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 181a1fa3de4cb..3d2adb235546a 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply()); + auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply(NameSupply(""))); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -121,7 +121,7 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - GlobalVarSupply global_var_supply = GlobalVarSupply(); + GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply("")); auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds, global_var_supply); auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds, global_var_supply); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc index 8dbd7d95e3c4a..442f76a8cff3c 100644 --- a/tests/cpp/c_codegen_test.cc +++ b/tests/cpp/c_codegen_test.cc @@ -52,7 +52,8 @@ TEST(CCodegen, MainFunctionOrder) { auto args = Array({A, B, elemwise_add}); std::unordered_map binds; - auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply()); + auto lowered = + LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply(NameSupply(""))); Map inputs = {{target_c, lowered}}; runtime::Module module = build(inputs, Target()); Array functions = module->GetFunction("get_func_names", false)(); @@ -81,7 +82,8 @@ auto BuildLowered(std::string op_name, tvm::Target target) { auto args = Array({A, B, op}); std::unordered_map binds; - auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply()); + auto lowered_s = + LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply(NameSupply(""))); return lowered_s; } diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc index 1bf167b03a871..75b9ae86a9ab9 100644 --- a/tests/cpp/name_supply_test.cc +++ b/tests/cpp/name_supply_test.cc @@ -36,14 +36,29 @@ TEST(NameSupply, FreshName) { NameSupply name_supply = preambleNameSupply(); String fresh = name_supply->FreshName("test"); - EXPECT_EQ(fresh.compare("prefix_test1"), 0); + EXPECT_EQ(fresh.compare("prefix_test_1"), 0); +} + +TEST(NameSupply, FreshNameNoConflict) { + NameSupply name_supply = preambleNameSupply(); + String fresh = name_supply->FreshName("name_2"); + EXPECT_EQ(fresh.compare("prefix_name_2"), 0); + + fresh = name_supply->FreshName("name"); + EXPECT_EQ(fresh.compare("prefix_name"), 0); + + fresh = name_supply->FreshName("name"); + EXPECT_EQ(fresh.compare("prefix_name_1"), 0); + + fresh = name_supply->FreshName("name"); + EXPECT_EQ(fresh.compare("prefix_name_3"), 0); } TEST(NameSupply, ContainsName) { NameSupply name_supply = preambleNameSupply(); EXPECT_TRUE(name_supply->ContainsName("test")); - EXPECT_FALSE(name_supply->ContainsName("test1")); + EXPECT_FALSE(name_supply->ContainsName("test_1")); } TEST(NameSupply, ReserveName) { @@ -52,10 +67,14 @@ TEST(NameSupply, ReserveName) { EXPECT_TRUE(name_supply->ContainsName("otherTest", false)); EXPECT_FALSE(name_supply->ContainsName("otherTest")); + + name_supply->ReserveName("otherTest"); + EXPECT_TRUE(name_supply->ContainsName("prefix_otherTest", false)); + EXPECT_TRUE(name_supply->ContainsName("otherTest")); } GlobalVarSupply preambleVarSupply() { - GlobalVarSupply global_var_supply = GlobalVarSupply(); + GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply("")); global_var_supply->FreshGlobal("test"); return global_var_supply; } @@ -66,8 +85,8 @@ TEST(GlobalVarSupply, FreshGlobal) { GlobalVar second_var = global_var_supply->FreshGlobal("test"); EXPECT_FALSE(tvm::StructuralEqual()(first_var, second_var)); - EXPECT_EQ(first_var->name_hint.compare("test1"), 0); - EXPECT_EQ(second_var->name_hint.compare("test2"), 0); + EXPECT_EQ(first_var->name_hint.compare("test_1"), 0); + EXPECT_EQ(second_var->name_hint.compare("test_2"), 0); } TEST(GlobalVarSupply, UniqueGlobalFor) { @@ -90,7 +109,7 @@ TEST(GlobalVarSupply, ReserveGlobal) { EXPECT_TRUE(tvm::StructuralEqual()(var, second_var)); EXPECT_FALSE(tvm::StructuralEqual()(var, third_var)); EXPECT_EQ(second_var->name_hint.compare("someName"), 0); - EXPECT_EQ(third_var->name_hint.compare("someName1"), 0); + EXPECT_EQ(third_var->name_hint.compare("someName_1"), 0); } TEST(GlobalVarSupply, BuildIRModule) { @@ -106,5 +125,5 @@ TEST(GlobalVarSupply, BuildIRModule) { EXPECT_TRUE(tvm::StructuralEqual()(var, second_var)); EXPECT_FALSE(tvm::StructuralEqual()(var, third_var)); EXPECT_EQ(second_var->name_hint.compare("test"), 0); - EXPECT_EQ(third_var->name_hint.compare("test1"), 0); + EXPECT_EQ(third_var->name_hint.compare("test_1"), 0); } diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py index 07282042b0522..b7d012ca04d6c 100644 --- a/tests/python/unittest/test_meta_schedule_multi_anchor.py +++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py @@ -47,7 +47,7 @@ def get_ref(data_np, weight1_np, weight2_np): def schedule_dense_dense(sch): dense1 = sch.get_block("T_matmul_NT") - dense2 = sch.get_block("T_matmul_NT1") + dense2 = sch.get_block("T_matmul_NT_1") _y1, _x1, _k1 = sch.get_loops(dense1) _y2, _x2, _k2 = sch.get_loops(dense2) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 503a44cd4bb3d..d3f444ec081f1 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -29,7 +29,7 @@ def test_unique_name_complete_block(): func = te.create_prim_func([A, C]) s = tir.Schedule(func, debug_mask="all") assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) - assert isinstance(s.get_sref(s.get_block("main1")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) def test_unique_name_reduction_block(): @@ -41,7 +41,7 @@ def test_unique_name_reduction_block(): func = te.create_prim_func([A, C]) s = tir.Schedule(func, debug_mask="all") assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef) - assert isinstance(s.get_sref(s.get_block("sum1")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) def _check_workload(te_workload, tir_workload):