Skip to content

Commit

Permalink
Construct GlobalVarSupply from IRModule
Browse files Browse the repository at this point in the history
  • Loading branch information
Florin-Gabriel Blanaru committed Jul 20, 2022
1 parent fa25ace commit 84ae2ef
Show file tree
Hide file tree
Showing 23 changed files with 84 additions and 141 deletions.
1 change: 1 addition & 0 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#ifndef TVM_DRIVER_DRIVER_API_H_
#define TVM_DRIVER_DRIVER_API_H_

#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_map>

#include "tvm/ir/expr.h"
#include "tvm/ir/module.h"
#include "tvm/ir/name_supply.h"

namespace tvm {
Expand Down Expand Up @@ -57,13 +58,12 @@ class GlobalVarSupplyNode : public Object {

class GlobalVarSupply : public ObjectRef {
public:
TVM_DLL explicit GlobalVarSupply(
const NameSupply& name_supply = NameSupply::NameSupplyWithPrefix(""),
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(),
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

TVM_DLL static GlobalVarSupply GlobalVarSupplyFromNameSupply(const NameSupply& name_supply);
TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);

TVM_DLL static GlobalVarSupply EmptySupply();
TVM_DLL explicit GlobalVarSupply(const IRModule module);

explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <tvm/ir/adt.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/type.h>
#include <tvm/parser/source_map.h>
#include <tvm/runtime/container/array.h>
Expand Down
6 changes: 1 addition & 5 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ class NameSupplyNode : public Object {

class NameSupply : public ObjectRef {
public:
TVM_DLL NameSupply();
TVM_DLL explicit NameSupply();

TVM_DLL explicit NameSupply(const String& prefix,
std::unordered_map<std::string, int> name_map = {});

TVM_DLL static NameSupply NameSupplyWithPrefix(const String& prefix = "");

TVM_DLL static NameSupply EmptySupply();

explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return mutable pointers to the node. */
NameSupplyNode* operator->() const {
Expand Down
25 changes: 20 additions & 5 deletions python/tvm/ir/supply.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Suppliers that are used to guarantee uniqueness of names and GlobalVars."""
import tvm
from tvm import Object
from tvm import Object, IRModule
from . import _ffi_api


Expand All @@ -38,6 +38,12 @@ def fresh_name(self, name, add_prefix=True):
def reserve_name(self, name, add_prefix=True):
return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)

def contains_name(self, name, add_prefix=True):
return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)

def clear(self):
return _ffi_api.NameSupply_Clear(self)


@tvm._ffi.register_object("GlobalVarSupply")
class GlobalVarSupply(Object):
Expand All @@ -48,15 +54,24 @@ class GlobalVarSupply(Object):
Parameters
----------
name_supply: The NameSupply to be used by this GlobalVarSupply.
value: Union[List[IRModule], IRModule, NameSupply]
The IRModules used to build this GlobalVarSupply or a NameSupply.
"""

def __init__(self, name_supply=None):
name_supply = name_supply if name_supply is not None else NameSupply("")
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply, name_supply)
def __init__(self, value=None):
if value is None:
name_supply = NameSupply("")
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply)
elif isinstance(value, (list, tvm.container.Array)):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
elif isinstance(value, IRModule):
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)

def fresh_global(self, name, add_prefix=True):
return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)

def unique_global_for(self, name, add_prefix=True):
return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)

def reserve_global(self, global_var, allow_conflict=False):
return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict)
4 changes: 2 additions & 2 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
Expand Down Expand Up @@ -1371,8 +1372,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();

auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
std::unordered_map<te::Tensor, te::Buffer>(),
GlobalVarSupply::EmptySupply());
std::unordered_map<te::Tensor, te::Buffer>(), GlobalVarSupply());

bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
Expand Down
3 changes: 2 additions & 1 deletion src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_

#include <tvm/ir/name_supply.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
Expand Down Expand Up @@ -146,7 +147,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief NameSupply for allocated ids. */
NameSupply ids_allocated = NameSupply::EmptySupply();
NameSupply ids_allocated = NameSupply();
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
Expand Down
6 changes: 2 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
c_binds.insert({kv.first, kv.second});
}
}
IRModule mod =
ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply::EmptySupply());
IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply());
return mod;
});

Expand Down Expand Up @@ -367,8 +366,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
c_binds.insert({kv.first, kv.second});
}
}
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply::EmptySupply(),
simple_mode);
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode);
});

/**
Expand Down
35 changes: 27 additions & 8 deletions src/ir/global_var_supply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,32 @@
#include "tvm/ir/expr.h"

namespace tvm {

GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply,
std::unordered_map<std::string, GlobalVar> name_to_var_map) {
auto n = make_object<GlobalVarSupplyNode>(name_supply);
n->name_to_var_map_ = std::move(name_to_var_map);
data_ = std::move(n);
}

GlobalVarSupply GlobalVarSupply::GlobalVarSupplyFromNameSupply(const NameSupply& name_supply) {
auto global_var_supply = GlobalVarSupply(name_supply);
return global_var_supply;
std::string GetModuleName(const IRModule& module) {
return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
}

GlobalVarSupply GlobalVarSupply::EmptySupply() {
return GlobalVarSupplyFromNameSupply(NameSupply::NameSupplyWithPrefix(""));
GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply() {
if (!modules.empty()) {
IRModule first_mod = modules.front();
this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
}
for (auto& mod : modules) {
for (auto kv : mod->functions) {
this->operator->()->ReserveGlobalVar(kv.first);
}
}
}

GlobalVarSupply::GlobalVarSupply(const IRModule module)
: GlobalVarSupply(Array<IRModule>{module}) {}

void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
name_supply_->ReserveName(var->name_hint, false);
if (!allow_conflict) {
Expand Down Expand Up @@ -79,8 +88,15 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) {

TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply").set_body_typed([](NameSupply name_supply) {
return GlobalVarSupply(name_supply);
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply")
.set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); });

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) {
return GlobalVarSupply(std::move(mod));
});

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array<IRModule>& mods) {
return GlobalVarSupply(mods);
});

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
Expand All @@ -89,4 +105,7 @@ TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
.set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);

TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
.set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);

} // namespace tvm
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#include <sstream>
#include <unordered_set>

#include "../relay/backend/supply_provider.h"
#include "tvm/ir/global_var_supply.h"

namespace tvm {

Expand Down Expand Up @@ -386,7 +386,7 @@ std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
}

GlobalVar main_gv;
auto global_var_supply = tvm::BuildGlobalVarSupply(mod);
auto global_var_supply = GlobalVarSupply(mod);
if (gv_name.empty()) {
// Bind function to 'main' (though rename if would clash with existing 'main').
main_gv = global_var_supply->FreshGlobal("main", false);
Expand Down
14 changes: 6 additions & 8 deletions src/ir/name_supply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,14 @@

namespace tvm {

NameSupply::NameSupply() { NameSupply("", {}); }
NameSupply::NameSupply() : NameSupply("") {}

NameSupply::NameSupply(const String& prefix, std::unordered_map<std::string, int> name_map) {
auto n = make_object<NameSupplyNode>(prefix);
n->name_map = std::move(name_map);
data_ = std::move(n);
}

NameSupply NameSupply::NameSupplyWithPrefix(const String& prefix) {
auto name_supply = NameSupply(prefix);
return name_supply;
}

NameSupply NameSupply::EmptySupply() { return NameSupply::NameSupplyWithPrefix(""); }

NameSupplyNode::NameSupplyNode(const String& prefix) : prefix_(prefix) {}

String NameSupplyNode::ReserveName(const String& name, bool add_prefix) {
Expand Down Expand Up @@ -114,4 +107,9 @@ TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName")
TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName")
.set_body_method<NameSupply>(&NameSupplyNode::ReserveName);

TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName")
.set_body_method<NameSupply>(&NameSupplyNode::ContainsName);

TVM_REGISTER_GLOBAL("ir.NameSupply_Clear").set_body_method<NameSupply>(&NameSupplyNode::Clear);

} // namespace tvm
2 changes: 1 addition & 1 deletion src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
/*! \brief function metadata */
Map<String, FunctionInfo> function_metadata_;
/*! \brief NameSupply */
NameSupply name_supply_ = NameSupply::EmptySupply();
NameSupply name_supply_ = NameSupply();
};

class GraphExecutorCodegenModule : public runtime::ModuleNode {
Expand Down
51 changes: 0 additions & 51 deletions src/relay/backend/supply_provider.cc

This file was deleted.

34 changes: 0 additions & 34 deletions src/relay/backend/supply_provider.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/relay/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
});
// Tasks are extracted via post order visit, return the reversed list.
std::reverse(tasks.begin(), tasks.end());
NameSupply name_supply = NameSupply::EmptySupply();
NameSupply name_supply = NameSupply();
for (ExtractedTask task : tasks) {
task->task_name = name_supply->FreshName(task->task_name);
}
Expand Down
Loading

0 comments on commit 84ae2ef

Please sign in to comment.