Skip to content

Commit

Permalink
Add documentation for NameSupply and GlobalVarSupply
Browse files Browse the repository at this point in the history
  • Loading branch information
Florin-Gabriel Blanaru committed Jul 27, 2022
1 parent 6962c2f commit 30c27af
Show file tree
Hide file tree
Showing 20 changed files with 213 additions and 71 deletions.
64 changes: 54 additions & 10 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* under the License.
*/

/*!
* \file tvm/ir/global_var_supply.h
* \brief GlobalVarSupply that can be used to generate unique \class GlobalVar.
*/
#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
#define TVM_IR_GLOBAL_VAR_SUPPLY_H_

Expand All @@ -29,20 +33,49 @@

namespace tvm {

/*!
* \brief GlobalVarSupply can be used to generate unique GlobalVars.
*/
class GlobalVarSupplyNode : public Object {
public:
/*!
* \brief Empty constructor. Will use an empty NameSupply.
*/
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}

/*!
* \brief Constructor.
* \param name_supply The NameSupply to use for generating the names of fresh GlobalVars.
*/
explicit GlobalVarSupplyNode(NameSupply name_supply);

/*!
* \brief Generates a unique GlobalVar from this supply.
* \param name The name from which the name of the GlobalVar is derived.
* \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended
* to the name. \return A unique GlobalVar.
*/
GlobalVar FreshGlobal(String name, bool add_prefix = true);

/*!
* \brief Looks up for a GlobalVar with the given name in this supply.
* If no entry is found, creates one, places it in the cache and returns it.
* \param name The name of the GlobalVar to search for.
* \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to
* the name before performing the search. \return A cached GlobalVar.
*/
GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);

/*!
* \brief Reserves an existing GlobalVar with this supply.
* \param var The GlobalVar to be registered.
* \param allow_conflict Allow conflict with other GlobalVars that have the same name.
*/
void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);

void VisitAttrs(AttrVisitor* v) { v->Visit("name_supply", &name_supply_); }

/*! \brief The NameSupply used to generate unique name hints to GlobalVars. */
NameSupply name_supply_;

static constexpr const char* _type_key = "GlobalVarSupply";
Expand All @@ -56,24 +89,35 @@ class GlobalVarSupplyNode : public Object {
friend class GlobalVarSupply;
};

/*!
* \brief Managed reference class to GlobalVarSupplyNode.
* \sa GlobalVarSupplyNode
*/
class GlobalVarSupply : public ObjectRef {
public:
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(),
/*!
* \brief Constructor.
* \param name_supply The NameSupply to be used when generating new GlobalVars.
* \param name_to_var_map An optional map.
*/
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
* \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are
* guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array
* of IRModules.
*/
TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);

/*!
* \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are
* guaranteed not to conflict with GlobalVars that belong to the modules. \param module The
* IRModule.
*/
TVM_DLL explicit GlobalVarSupply(const IRModule module);

explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
GlobalVarSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<GlobalVarSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode);
};

} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ namespace attr {
*
* \sa tvm::runtime::String
*/
constexpr const char* kModuleName = "name";
constexpr const char* kModuleName = "mod_name";

/*!
* \brief Executor targeted by the module
Expand Down
66 changes: 46 additions & 20 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* under the License.
*/

/*!
* \file tvm/ir/name_supply.h
* \brief NameSupply that can be used to generate unique variable names.
*/
#ifndef TVM_IR_NAME_SUPPLY_H_
#define TVM_IR_NAME_SUPPLY_H_

Expand All @@ -27,20 +31,37 @@

namespace tvm {

/*!
* \brief NameSupply can be used to generate unique names.
*/
class NameSupplyNode : public Object {
public:
NameSupplyNode() : NameSupplyNode("") {}

explicit NameSupplyNode(const String& prefix);

/*!
* \brief Generates a unique name from this NameSupply.
* \param name The name from which the generated name is derived.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name. \return A unique name.
*/
String FreshName(const String& name, bool add_prefix = true);

/*!
* \brief Reserves an existing name with this NameSupply.
* \param name The name to be reserved.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before reserving it. \return The name that was reserved with the NameSupply. It can be
* different if a prefix is added.
*/
String ReserveName(const String& name, bool add_prefix = true);

/*!
* \brief Checks if this NameSupply already generated a name.
* \param name The name to check.
* \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the
* name before checking for it. \return True if the name has already been generated. False
* otherwise.
*/
bool ContainsName(const String& name, bool add_prefix = true);

void Clear();

void VisitAttrs(AttrVisitor* v) { v->Visit("prefix", &prefix_); }

// Prefix for all GlobalVar names. It can be empty.
Expand All @@ -52,32 +73,37 @@ class NameSupplyNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);

private:
String prefix_module_name(const String& name);

/*! \brief Helper function to add the NameSupply prefix to the name. */
String add_prefix_to_name(const String& name);

/*!
* \brief Function that will generate a unique name.
* \param name The name to be used as a base.
* \return A unique name.
*/
std::string GetUniqueName(std::string name);

// Key is function_name. Value is a counter.
/*! \brief A map that is used to generate unique names. */
std::unordered_map<std::string, int> name_map;

friend class NameSupply;
};

/*!
* \brief Managed reference class to NameSupplyNode.
* \sa NameSupplyNode
*/
class NameSupply : public ObjectRef {
public:
TVM_DLL explicit NameSupply();

/*!
* \brief Constructor.
* \param prefix The prefix to be used with this NameSupply.
* \param name_map An optional map.
*/
TVM_DLL explicit NameSupply(const String& prefix,
std::unordered_map<std::string, int> name_map = {});

explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
NameSupplyNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<NameSupplyNode*>(ptr);
}

TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
};

} // namespace tvm
Expand Down
70 changes: 66 additions & 4 deletions python/tvm/ir/supply.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,52 @@ def __init__(self, prefix=""):
self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix)

def fresh_name(self, name, add_prefix=True):
"""Generates a unique name from this NameSupply.
Parameters
----------
name: String
The name from which the generated name is derived.
add_prefix: bool
If set to true, then the prefix of this NameSupply will be prepended to the name.
"""
return _ffi_api.NameSupply_FreshName(self, name, add_prefix)

def reserve_name(self, name, add_prefix=True):
"""Reserves an existing name with this NameSupply.
Parameters
----------
name: String
The name to be reserved.
add_prefix: bool
If set to true, then the prefix of this NameSupply will be prepended to the name
before reserving it.
"""
return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)

def contains_name(self, name, add_prefix=True):
return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)
"""Checks if this NameSupply already generated a name.
def clear(self):
return _ffi_api.NameSupply_Clear(self)
Parameters
----------
name: String
The name to check.
add_prefix: bool
If set to true, then the prefix of this NameSupply will be prepended to the name
before checking for it.
"""
return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)


@tvm._ffi.register_object("GlobalVarSupply")
class GlobalVarSupply(Object):
"""GlobalVarSupply that holds a mapping between names and GlobalVars.
GlobalVarSupply can be used to generate new GlobalVars with an unique name.
GlobalVarSupply can be used to generate new GlobalVars with a unique name.
It also can be used to retrieve previously generated GlobalVars based on a name.
Parameters
Expand All @@ -70,10 +99,43 @@ def __init__(self, value=None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)

def fresh_global(self, name, add_prefix=True):
"""Generates a unique GlobalVar from this supply.
Parameters
----------
name: String
The name from which the name of the GlobalVar is derived.
add_prefix: bool
If set to true, then the prefix of the contained NameSupply will be prepended
to the name.
"""
return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)

def unique_global_for(self, name, add_prefix=True):
"""Looks up for a GlobalVar with the given name in this supply. If no entry is found
, creates one, places it in the cache and returns it.
Parameters
----------
name: String
The name of the GlobalVar to search for.
add_prefix: bool
If set to true, the prefix of the contained NameSupply will be prepended to the
name before performing the search.
"""
return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)

def reserve_global(self, global_var, allow_conflict=False):
"""Reserves an existing GlobalVar with this supply.
Parameters
----------
global_var: GlobalVar
The GlobalVar to be registered.
allow_conflict: bool
Allow conflict with other GlobalVars that have the same name
"""
return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict)
3 changes: 2 additions & 1 deletion src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,8 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();

auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
std::unordered_map<te::Tensor, te::Buffer>(), GlobalVarSupply());
std::unordered_map<te::Tensor, te::Buffer>(),
GlobalVarSupply(NameSupply("")));

bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief NameSupply for allocated ids. */
NameSupply ids_allocated = NameSupply();
NameSupply ids_allocated = NameSupply("");
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
Expand Down
6 changes: 4 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
c_binds.insert({kv.first, kv.second});
}
}
IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply());
IRModule mod =
ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")));
return mod;
});

Expand Down Expand Up @@ -366,7 +367,8 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
c_binds.insert({kv.first, kv.second});
}
}
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode);
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")),
simple_mode);
});

/**
Expand Down
6 changes: 5 additions & 1 deletion src/ir/global_var_supply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* under the License.
*/

/*!
* \file global_var_supply.cc
* \brief GlobalVarSupply that can be used to generate unique GlobalVars.
*/
#include "tvm/ir/global_var_supply.h"

#include <tvm/runtime/registry.h>
Expand All @@ -37,7 +41,7 @@ std::string GetModuleName(const IRModule& module) {
return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
}

GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply() {
GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("")) {
if (!modules.empty()) {
IRModule first_mod = modules.front();
this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
Expand Down
Loading

0 comments on commit 30c27af

Please sign in to comment.