diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index 2bad0d8da8ec1..608690131cba2 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file tvm/ir/global_var_supply.h + * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar. + */ #ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_ #define TVM_IR_GLOBAL_VAR_SUPPLY_H_ @@ -29,20 +33,49 @@ namespace tvm { +/*! + * \brief GlobalVarSupply can be used to generate unique GlobalVars. + */ class GlobalVarSupplyNode : public Object { public: + /*! + * \brief Empty constructor. Will use an empty NameSupply. + */ GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {} + /*! + * \brief Constructor. + * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars. + */ explicit GlobalVarSupplyNode(NameSupply name_supply); + /*! + * \brief Generates a unique GlobalVar from this supply. + * \param name The name from which the name of the GlobalVar is derived. + * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended + * to the name. \return A unique GlobalVar. + */ GlobalVar FreshGlobal(String name, bool add_prefix = true); + /*! + * \brief Looks up for a GlobalVar with the given name in this supply. + * If no entry is found, creates one, places it in the cache and returns it. + * \param name The name of the GlobalVar to search for. + * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to + * the name before performing the search. \return A cached GlobalVar. + */ GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true); + /*! + * \brief Reserves an existing GlobalVar with this supply. + * \param var The GlobalVar to be registered. + * \param allow_conflict Allow conflict with other GlobalVars that have the same name. + */ void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false); void VisitAttrs(AttrVisitor* v) { v->Visit("name_supply", &name_supply_); } + /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ NameSupply name_supply_; static constexpr const char* _type_key = "GlobalVarSupply"; @@ -56,24 +89,35 @@ class GlobalVarSupplyNode : public Object { friend class GlobalVarSupply; }; +/*! + * \brief Managed reference class to GlobalVarSupplyNode. + * \sa GlobalVarSupplyNode + */ class GlobalVarSupply : public ObjectRef { public: - TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(), + /*! + * \brief Constructor. + * \param name_supply The NameSupply to be used when generating new GlobalVars. + * \param name_to_var_map An optional map. + */ + TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map = {}); + /*! + * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are + * guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array + * of IRModules. + */ TVM_DLL explicit GlobalVarSupply(const Array& modules); + /*! + * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are + * guaranteed not to conflict with GlobalVars that belong to the modules. \param module The + * IRModule. + */ TVM_DLL explicit GlobalVarSupply(const IRModule module); - explicit GlobalVarSupply(ObjectPtr n) : ObjectRef(n) {} - /*! \return mutable pointers to the node. */ - GlobalVarSupplyNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - - TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode); }; } // namespace tvm diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 28cf43ae86845..7313b4f783492 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -480,7 +480,7 @@ namespace attr { * * \sa tvm::runtime::String */ -constexpr const char* kModuleName = "name"; +constexpr const char* kModuleName = "mod_name"; /*! * \brief Executor targeted by the module diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 26bd1b3c1de81..303c06caaa3c7 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file tvm/ir/name_supply.h + * \brief NameSupply that can be used to generate unique variable names. + */ #ifndef TVM_IR_NAME_SUPPLY_H_ #define TVM_IR_NAME_SUPPLY_H_ @@ -27,20 +31,37 @@ namespace tvm { +/*! + * \brief NameSupply can be used to generate unique names. + */ class NameSupplyNode : public Object { public: - NameSupplyNode() : NameSupplyNode("") {} - - explicit NameSupplyNode(const String& prefix); - + /*! + * \brief Generates a unique name from this NameSupply. + * \param name The name from which the generated name is derived. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name. \return A unique name. + */ String FreshName(const String& name, bool add_prefix = true); + /*! + * \brief Reserves an existing name with this NameSupply. + * \param name The name to be reserved. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before reserving it. \return The name that was reserved with the NameSupply. It can be + * different if a prefix is added. + */ String ReserveName(const String& name, bool add_prefix = true); + /*! + * \brief Checks if this NameSupply already generated a name. + * \param name The name to check. + * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the + * name before checking for it. \return True if the name has already been generated. False + * otherwise. + */ bool ContainsName(const String& name, bool add_prefix = true); - void Clear(); - void VisitAttrs(AttrVisitor* v) { v->Visit("prefix", &prefix_); } // Prefix for all GlobalVar names. It can be empty. @@ -52,32 +73,37 @@ class NameSupplyNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); private: - String prefix_module_name(const String& name); - + /*! \brief Helper function to add the NameSupply prefix to the name. */ + String add_prefix_to_name(const String& name); + + /*! + * \brief Function that will generate a unique name. + * \param name The name to be used as a base. + * \return A unique name. + */ std::string GetUniqueName(std::string name); - // Key is function_name. Value is a counter. + /*! \brief A map that is used to generate unique names. */ std::unordered_map name_map; friend class NameSupply; }; +/*! + * \brief Managed reference class to NameSupplyNode. + * \sa NameSupplyNode + */ class NameSupply : public ObjectRef { public: - TVM_DLL explicit NameSupply(); - + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this NameSupply. + * \param name_map An optional map. + */ TVM_DLL explicit NameSupply(const String& prefix, std::unordered_map name_map = {}); - explicit NameSupply(ObjectPtr n) : ObjectRef(n) {} - /*! \return mutable pointers to the node. */ - NameSupplyNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - - TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode); }; } // namespace tvm diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index cd4290daefa63..095ac43c03b85 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -33,23 +33,52 @@ def __init__(self, prefix=""): self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) def fresh_name(self, name, add_prefix=True): + """Generates a unique name from this NameSupply. + + Parameters + ---------- + name: String + The name from which the generated name is derived. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name. + """ return _ffi_api.NameSupply_FreshName(self, name, add_prefix) def reserve_name(self, name, add_prefix=True): + """Reserves an existing name with this NameSupply. + + Parameters + ---------- + name: String + The name to be reserved. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name + before reserving it. + """ return _ffi_api.NameSupply_ReserveName(self, name, add_prefix) def contains_name(self, name, add_prefix=True): - return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) + """Checks if this NameSupply already generated a name. - def clear(self): - return _ffi_api.NameSupply_Clear(self) + Parameters + ---------- + name: String + The name to check. + + add_prefix: bool + If set to true, then the prefix of this NameSupply will be prepended to the name + before checking for it. + """ + return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) @tvm._ffi.register_object("GlobalVarSupply") class GlobalVarSupply(Object): """GlobalVarSupply that holds a mapping between names and GlobalVars. - GlobalVarSupply can be used to generate new GlobalVars with an unique name. + GlobalVarSupply can be used to generate new GlobalVars with a unique name. It also can be used to retrieve previously generated GlobalVars based on a name. Parameters @@ -70,10 +99,43 @@ def __init__(self, value=None): self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value) def fresh_global(self, name, add_prefix=True): + """Generates a unique GlobalVar from this supply. + + Parameters + ---------- + name: String + The name from which the name of the GlobalVar is derived. + + add_prefix: bool + If set to true, then the prefix of the contained NameSupply will be prepended + to the name. + """ return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix) def unique_global_for(self, name, add_prefix=True): + """Looks up for a GlobalVar with the given name in this supply. If no entry is found + , creates one, places it in the cache and returns it. + + Parameters + ---------- + name: String + The name of the GlobalVar to search for. + + add_prefix: bool + If set to true, the prefix of the contained NameSupply will be prepended to the + name before performing the search. + """ return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix) def reserve_global(self, global_var, allow_conflict=False): + """Reserves an existing GlobalVar with this supply. + + Parameters + ---------- + global_var: GlobalVar + The GlobalVar to be registered. + + allow_conflict: bool + Allow conflict with other GlobalVars that have the same name + """ return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index eab5c9ec2e7c2..c930bf0c4e73b 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1372,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i auto pass_ctx = tvm::transform::PassContext::Current(); auto mod = ScheduleToModule(sch, Array{tensors.begin(), tensors.end()}, name, - std::unordered_map(), GlobalVarSupply()); + std::unordered_map(), + GlobalVarSupply(NameSupply(""))); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index bbd83c9500e76..53026c7fc3b30 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -147,7 +147,7 @@ class CodeGenHybrid : public ExprFunctor, /*! \brief Print the current indent spaces. */ inline void PrintIndent(); /*! \brief NameSupply for allocated ids. */ - NameSupply ids_allocated = NameSupply(); + NameSupply ids_allocated = NameSupply(""); /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d95e7fca24422..25fdeb6da1c2b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -303,7 +303,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") c_binds.insert({kv.first, kv.second}); } } - IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply()); + IRModule mod = + ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply(""))); return mod; }); @@ -366,7 +367,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") c_binds.insert({kv.first, kv.second}); } } - return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode); + return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")), + simple_mode); }); /** diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index f1a736e8d0fab..3d599b93b5877 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file global_var_supply.cc + * \brief GlobalVarSupply that can be used to generate unique GlobalVars. + */ #include "tvm/ir/global_var_supply.h" #include @@ -37,7 +41,7 @@ std::string GetModuleName(const IRModule& module) { return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); } -GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply() { +GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply(NameSupply("")) { if (!modules.empty()) { IRModule first_mod = modules.front(); this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 9fd6df49d55d8..cc828cebfa0d7 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -17,6 +17,10 @@ * under the License. */ +/*! + * \file name_supply.cc + * \brief NameSupply that can be used to generate unique variable names. + */ #include "tvm/ir/name_supply.h" #include @@ -25,20 +29,17 @@ namespace tvm { -NameSupply::NameSupply() : NameSupply("") {} - NameSupply::NameSupply(const String& prefix, std::unordered_map name_map) { - auto n = make_object(prefix); + auto n = make_object(); + n->prefix_ = prefix; n->name_map = std::move(name_map); data_ = std::move(n); } -NameSupplyNode::NameSupplyNode(const String& prefix) : prefix_(prefix) {} - String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { String final_name = name; if (add_prefix) { - final_name = prefix_module_name(name); + final_name = add_prefix_to_name(name); } name_map[final_name] = 0; return final_name; @@ -47,7 +48,7 @@ String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { String NameSupplyNode::FreshName(const String& name, bool add_prefix) { String unique_name = name; if (add_prefix) { - unique_name = prefix_module_name(name); + unique_name = add_prefix_to_name(name); } unique_name = GetUniqueName(unique_name); return unique_name; @@ -56,20 +57,18 @@ String NameSupplyNode::FreshName(const String& name, bool add_prefix) { bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { String unique_name = name; if (add_prefix) { - unique_name = prefix_module_name(name); + unique_name = add_prefix_to_name(name); } return name_map.count(unique_name); } -void NameSupplyNode::Clear() { name_map.clear(); } - -String NameSupplyNode::prefix_module_name(const String& name) { +String NameSupplyNode::add_prefix_to_name(const String& name) { if (prefix_.empty()) { return name; } - std::stringstream ss; + std::ostringstream ss; ICHECK(name.defined()); ss << prefix_ << "_" << name; return ss.str(); @@ -110,6 +109,4 @@ TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName") TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName") .set_body_method(&NameSupplyNode::ContainsName); -TVM_REGISTER_GLOBAL("ir.NameSupply_Clear").set_body_method(&NameSupplyNode::Clear); - } // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 574bda77f92b7..ab725d82e6760 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -630,7 +630,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief NameSupply */ - NameSupply name_supply_ = NameSupply(); + NameSupply name_supply_ = NameSupply(""); }; class GraphExecutorCodegenModule : public runtime::ModuleNode { diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 9e53db77980bd..c577e8e356d64 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -74,7 +74,7 @@ Array ExtractTask( }); // Tasks are extracted via post order visit, return the reversed list. std::reverse(tasks.begin(), tasks.end()); - NameSupply name_supply = NameSupply(); + NameSupply name_supply = NameSupply(""); for (ExtractedTask task : tasks) { task->task_name = name_supply->FreshName(task->task_name); } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b55cce643e92c..e57b9e102eb81 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -142,7 +142,7 @@ class TECompilerImpl : public TECompilerNode { } else { mod_name = opt_mod_name.value(); } - NameSupply name_supply = NameSupply(mod_name); + NameSupply name_supply = NameSupply(mod_name /* prefix */); global_var_supply = GlobalVarSupply(name_supply); // Make sure we don't collide with any existing globals in the module. if (opt_mod) { @@ -157,8 +157,8 @@ class TECompilerImpl : public TECompilerNode { return LowerInternal(key, global_var_supply)->cached_func; } - // TODO(gigiblender): Only to be called by the GlobalTECompiler. - // Remove this when the GlobalTECompiler is removed. + // TODO(gigiblender): Only to be called by the global TE compiler. + // Remove this when the global TE compiler is removed. CachedFunc Lower(const CCacheKey& key, const String mod_name) { global_var_supply->name_supply_->prefix_ = mod_name; return LowerInternal(key, global_var_supply)->cached_func; @@ -166,7 +166,7 @@ class TECompilerImpl : public TECompilerNode { // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { - CCacheValue value = LowerInternal(key, GlobalVarSupply()); + CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply(""))); if (value->packed_func != nullptr) { return value->packed_func; } @@ -527,7 +527,7 @@ class TECompilerImpl : public TECompilerNode { /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal GlobalVarSupply to get unique GlobalVars */ - GlobalVarSupply global_var_supply; + GlobalVarSupply global_var_supply = GlobalVarSupply(); /*! \brief internal compiler cache */ std::unordered_map cache_; /*! \brief internal compiler cache for shape funcs */ diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 07ebc68974634..25111cec8eda1 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply()); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc index a6682f2cf4898..8a70f224c611e 100644 --- a/src/relay/transforms/meta_schedule_layout_rewrite.cc +++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc @@ -127,7 +127,7 @@ Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) { if (const auto* func = call->op.as()) { LayoutIndexQueue* self = LayoutIndexQueue::Global(); self->queue_.clear(); - tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply()); + tec::PrimFuncFor(GetRef(func), Target::Current(), GlobalVarSupply(NameSupply(""))); if (!self->queue_.empty()) { std::deque queue = std::move(self->queue_); self->queue_.clear(); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 66c56ad4222f0..75833fd93629d 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -28,7 +28,7 @@ namespace tvm { namespace codegen { void CodeGenSourceBase::ClearFuncState() { - name_supply_->Clear(); + name_supply_ = NameSupply(""); ssa_assign_map_.clear(); var_idmap_.clear(); scope_mark_.clear(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 4a764eeb23317..2fd0abcd68a63 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -123,7 +123,7 @@ class CodeGenSourceBase { /*! \brief name of each variable */ std::unordered_map var_idmap_; /*! \brief NameSupply for allocation */ - NameSupply name_supply_ = NameSupply(); + NameSupply name_supply_ = NameSupply(""); private: /*! \brief assignment map of ssa */ diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index ccb7a56cc8daa..6521246ff41f5 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -63,7 +63,7 @@ struct CreateFuncInfo { /*! \brief The buffers should be allocated at function root. */ Array root_alloc; /*! \brief The NameSupply to make block name unique. */ - NameSupply name_supply = NameSupply(); + NameSupply name_supply = NameSupply(""); String FreshName(String base_name) { return name_supply->FreshName(base_name); } diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 181a1fa3de4cb..3d2adb235546a 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) { auto target = Target("llvm"); - auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply()); + auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply(NameSupply(""))); auto module = build(lowered, target, Target()); auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -121,7 +121,7 @@ TEST(BuildModule, Heterogeneous) { auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - GlobalVarSupply global_var_supply = GlobalVarSupply(); + GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply("")); auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds, global_var_supply); auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds, global_var_supply); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; diff --git a/tests/cpp/c_codegen_test.cc b/tests/cpp/c_codegen_test.cc index 8dbd7d95e3c4a..442f76a8cff3c 100644 --- a/tests/cpp/c_codegen_test.cc +++ b/tests/cpp/c_codegen_test.cc @@ -52,7 +52,8 @@ TEST(CCodegen, MainFunctionOrder) { auto args = Array({A, B, elemwise_add}); std::unordered_map binds; - auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply()); + auto lowered = + LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply(NameSupply(""))); Map inputs = {{target_c, lowered}}; runtime::Module module = build(inputs, Target()); Array functions = module->GetFunction("get_func_names", false)(); @@ -81,7 +82,8 @@ auto BuildLowered(std::string op_name, tvm::Target target) { auto args = Array({A, B, op}); std::unordered_map binds; - auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply()); + auto lowered_s = + LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply(NameSupply(""))); return lowered_s; } diff --git a/tests/cpp/name_supply_test.cc b/tests/cpp/name_supply_test.cc index 1bf167b03a871..ff2707cbfbdfd 100644 --- a/tests/cpp/name_supply_test.cc +++ b/tests/cpp/name_supply_test.cc @@ -52,10 +52,14 @@ TEST(NameSupply, ReserveName) { EXPECT_TRUE(name_supply->ContainsName("otherTest", false)); EXPECT_FALSE(name_supply->ContainsName("otherTest")); + + name_supply->ReserveName("otherTest"); + EXPECT_TRUE(name_supply->ContainsName("prefix_otherTest", false)); + EXPECT_TRUE(name_supply->ContainsName("otherTest")); } GlobalVarSupply preambleVarSupply() { - GlobalVarSupply global_var_supply = GlobalVarSupply(); + GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply("")); global_var_supply->FreshGlobal("test"); return global_var_supply; }