Commit cd1318b: Masa & Andrew's comments
mbs-octoml committed Jun 27, 2022
1 parent 0383153 commit cd1318b
Showing 16 changed files with 61 additions and 65 deletions.
6 changes: 3 additions & 3 deletions include/tvm/ir/module.h
@@ -519,12 +519,12 @@ constexpr const char* kConstantMemoryPools = "constant_memory_pools";
 
 /*
  * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
- * node will record the index into this array. See also kConstNameToNDArray below, which is
+ * node will record the index into this array. See also kConstNameToConstant below, which is
  * the analog for Relay Functions.
  *
  * Type: Array<runtime::NDArray>
  */
-constexpr const char* kConstantsArray = "Constants";
+constexpr const char* kConstants = "constants";
 
 /*!
  * \brief All the runtime::Modules accumulated during compilation by external codegen. These
@@ -542,7 +542,7 @@ constexpr const char* kExternalMods = "external_mods";
  *
  * Type: Map<String, runtime::NDArray>
  */
-constexpr const char* kConstNameToNDArray = "const_name_to_ndarray";
+constexpr const char* kConstNameToConstant = "const_name_to_constant";
 
 } // namespace attr
 } // namespace tvm
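Note: both renamed keys are ordinary module-level attributes, so producers and consumers go
through the generic GetAttr / WithAttr helpers rather than a dedicated API. A minimal sketch of
that access pattern, assuming only the headers named in this diff; the helper AttachConstants is
hypothetical and not part of this commit:

// Hypothetical helper: bind extra named constants onto an IRModule via the
// renamed kConstNameToConstant attribute key.
#include <tvm/ir/module.h>
#include <tvm/runtime/ndarray.h>

namespace tvm {

IRModule AttachConstants(IRModule mod, const Map<String, runtime::NDArray>& consts) {
  // Read the current binding map, defaulting to empty when the attribute was
  // never set (the value_or({}) idiom introduced by this commit).
  Map<String, runtime::NDArray> existing =
      mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({});
  for (const auto& kv : consts) {
    existing.Set(kv.first, kv.second);
  }
  // Module attributes are functional: WithAttr returns an updated module.
  return WithAttr(std::move(mod), tvm::attr::kConstNameToConstant, existing);
}

}  // namespace tvm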
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
@@ -195,7 +195,7 @@ class Interpreter(Executor):
         The runtime device to run the code on.
 
     target : tvm.Target
-        The target option to build the function.
+        The target option to build the function. Only homogeneous execution is supported.
 
     CAUTION: Despite the API the module is prepared upon each call to evaluate
     rather than once in create_executor.
5 changes: 3 additions & 2 deletions python/tvm/relay/backend/vm.py
@@ -198,8 +198,9 @@ class VMExecutor(Executor):
     device : :py:class:`~tvm.runtime.Device`
         The runtime device to run the code on.
-    target : :py:class:`Target`
-        The target option to build the function.
+    target : any multi-target like object, see Target.canon_multi_target
+        For homogeneous compilation, the unique build target.
+        For heterogeneous compilation, a dictionary or list of possible build targets.
     """
 
     def __init__(self, mod, device, target):
28 changes: 17 additions & 11 deletions python/tvm/relay/build_module.py
@@ -570,15 +570,16 @@ class GraphExecutor(_interpreter.Executor):
     device : :py:class:`Device`
         The runtime device to run the code on.
-    raw_targets : Array[tvm.target.Target]
-        The available targets.
+    target : any multi-target like object, see Target.canon_multi_target
+        For homogeneous compilation, the unique build target.
+        For heterogeneous compilation, a dictionary or list of possible build targets.
     """
 
-    def __init__(self, mod, device, raw_targets):
+    def __init__(self, mod, device, target):
         assert mod is not None
         self.mod = mod
         self.device = device
-        self.raw_targets = raw_targets
+        self.target = target
 
     def _make_executor(self, expr=None):
         if expr:
@@ -589,7 +590,7 @@ def _make_executor(self, expr=None):
             raise ValueError(
                 "Graph Executor only supports static graphs, got output type", ret_type
             )
-        mod = build(self.mod, target=self.raw_targets)
+        mod = build(self.mod, target=self.target)
         gmodule = _graph_executor.GraphModule(mod["default"](self.device))
 
         def _unflatten(flat_iter, cur_type):
@@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor):
     device : :py:class:`Device`
         The runtime device to run the code on.
-    raw_targets : Array[tvm.target.Target]
-        The available targets.
+    target : any multi-target like object, see Target.canon_multi_target
+        For homogeneous compilation, the unique build target.
+        For heterogeneous compilation, a dictionary or list of possible build targets.
     """
 
-    def __init__(self, mod, device, raw_targets):
+    def __init__(self, mod, device, target):
         assert mod is not None
         self.mod = mod
         self.device = device
-        self.raw_targets = raw_targets
-        assert raw_targets[0].attrs.get("executor", "graph") == "aot"
+        self.target = target
 
     def _make_executor(self, expr=None):
         if expr:
@@ -648,7 +649,7 @@ def _make_executor(self, expr=None):
         ret_type = self.mod["main"].checked_type.ret_type
         if _ty.is_dynamic(ret_type):
             raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
-        mod = build(self.mod, target=self.raw_targets)
+        mod = build(self.mod, target=self.target)
 
         # NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
         # generated code.
@@ -722,6 +723,8 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None):
     target : any multi-target like object, see Target.canon_multi_target
         For homogeneous compilation, the unique build target.
         For heterogeneous compilation, a dictionary or list of possible build targets.
+        CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so
+        heterogeneous compilation is not yet supported.
 
     params : dict of str to NDArray
         Input parameters to the graph that do not change
@@ -737,11 +740,14 @@
     if device is not None:
         assert device.device_type == raw_targets[0].kind.device_type
     else:
+        # Use the first target as the device.
         device = _nd.device(raw_targets[0].kind.device_type, 0)
 
     if params is not None:
         mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))
 
+    assert raw_targets[0].attrs.get("executor") == kind
+
     if kind == "debug":
         assert len(raw_targets) == 1, "The interpreter currently only supports a single target"
         return _interpreter.Interpreter(mod, device, raw_targets[0])
16 changes: 6 additions & 10 deletions src/relay/backend/aot_executor_codegen.cc
@@ -1170,12 +1170,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 
     // Collect any constants extracted by external codegen.
     ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
-    Map<String, runtime::NDArray> const_name_to_ndarray =
-        lowered_mod
-            ->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToNDArray,
-                                                     Map<String, runtime::NDArray>())
-            .value();
-    for (const auto& kv : const_name_to_ndarray) {
+    Map<String, runtime::NDArray> const_name_to_constant =
+        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
+            .value_or({});
+    for (const auto& kv : const_name_to_constant) {
       ICHECK(ret.params.emplace(kv.first, kv.second).second);
     }
 
@@ -1223,10 +1221,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     // Collect any runtime modules generated by external codegen.
-    ret.external_mods = lowered_mod
-                            ->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods,
-                                                                   Array<tvm::runtime::Module>())
-                            .value();
+    ret.external_mods =
+        lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});
 
     // This is the point where we separate the functions in the module by target
     VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
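Note: the hunks above are instances of a mechanical cleanup repeated through the C++ files in
this commit: GetAttr with an explicitly constructed default followed by .value() becomes a bare
GetAttr followed by .value_or({}). A side-by-side sketch of the two spellings using the
kExternalMods key from module.h; both hypothetical functions return an empty array when the
attribute is absent:

#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>

namespace tvm {

// Old spelling: pass the default into GetAttr, then unconditionally unwrap.
Array<runtime::Module> ExternalModsOld(const IRModule& mod) {
  return mod
      ->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
      .value();
}

// New spelling: let the Optional supply the default, without repeating the
// verbose container type.
Array<runtime::Module> ExternalModsNew(const IRModule& mod) {
  return mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
}

}  // namespace tvm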
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -393,7 +393,7 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
   serializer.serialize();
   std::string graph_json = serializer.GetJSON();
 
-  // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
+  // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
   // a callback which calls backend::UpdateConstants to capture the map before the function
   // 'disappears' into lowered form, on the assumption the visit order and thus constant
   // names match those generated by the JSONSerializer.
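Note: the comment above describes the constant-capture arrangement only in prose. A sketch of
what such a process function does, assuming the backend::UpdateConstants helper declared in the
internal header src/relay/backend/utils.h; the function name and exact signature here are
approximations, not part of this diff:

#include <string>
#include <unordered_map>

#include <tvm/relay/function.h>
#include <tvm/runtime/ndarray.h>

namespace tvm {
namespace relay {

// Invoked by the TECompiler on each external function before it is lowered.
// Because this walks the function in the same deterministic order as the
// JSONSerializer, the captured names line up with "<symbol>_const_<i>".
// Assumes backend::UpdateConstants from src/relay/backend/utils.h.
void ProcessLoweredExternalFunction(
    const Function& func, std::unordered_map<std::string, runtime::NDArray>* params) {
  backend::UpdateConstants(func, params);  // fills 'params' in visit order
}

}  // namespace relay
}  // namespace tvm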
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/bnns/codegen.cc
@@ -137,7 +137,7 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) {
   serializer.serialize();
   std::string graph_json = serializer.GetJSON();
 
-  // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
+  // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
   // a callback which calls backend::UpdateConstants to capture the map before the function
   // 'disappears' into lowered form, on the assumption the visit order and thus constant
   // names match those generated by the JSONSerializer.
13 changes: 7 additions & 6 deletions src/relay/backend/contrib/codegen_json/codegen_json.h
@@ -169,8 +169,8 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEntry>> {
    * \brief Returns the accumulated map from constant names to the NDArray they must be bound to
    * at runtime. Also referred to as 'params' elsewhere in the code.
    */
-  const std::unordered_map<std::string, runtime::NDArray>& const_name_to_ndarray() const {
-    return const_name_to_ndarray_;
+  const std::unordered_map<std::string, runtime::NDArray>& const_name_to_constant() const {
+    return const_name_to_constant_;
   }
 
   /*!
@@ -260,9 +260,10 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEntry>> {
 
   std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* constant_node) {
     std::string name = symbol_ + "_const_" + std::to_string(const_names_.size());
-    VLOG(1) << "Will require parameter '" << name << "' to be available at runtime";
-    ICHECK_EQ(const_name_to_ndarray_.count(name), 0);
-    const_name_to_ndarray_.emplace(name, constant_node->data);
+    VLOG(1) << "Will require parameter '" << name
+            << "' to be supplied by the ConstLoaderModule at runtime";
+    ICHECK_EQ(const_name_to_constant_.count(name), 0);
+    const_name_to_constant_.emplace(name, constant_node->data);
     const_names_.push_back(name);
     auto node = std::make_shared<JSONGraphNode>(name, /*op_type=*/"const");
     return AddNode(node, GetRef<Expr>(constant_node));
@@ -361,7 +362,7 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEntry>> {
    * translation to JSON. The JSON will record only the constant name. The actual NDArray must
    * be made available at runtime from a ConstLoaderModule.
    */
-  std::unordered_map<std::string, runtime::NDArray> const_name_to_ndarray_;
+  std::unordered_map<std::string, runtime::NDArray> const_name_to_constant_;
   /*!
    * \brief The domain of the above map, but in order the constants were encountered during
    * translation.
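Note: the naming scheme in VisitExpr_ above is what keeps the serializer and the TECompiler
constant-capture callback in sync. A standalone, runnable re-implementation of just the name
generation; the symbol string below is illustrative only:

#include <iostream>
#include <string>
#include <vector>

// Mirrors JSONSerializer: the i-th constant of the external function with
// global symbol 'symbol' is named "<symbol>_const_<i>", and that name is what
// the ConstLoaderModule resolves at runtime.
std::string NextConstName(const std::string& symbol, std::vector<std::string>* const_names) {
  std::string name = symbol + "_const_" + std::to_string(const_names->size());
  const_names->push_back(name);
  return name;
}

int main() {
  std::vector<std::string> const_names;
  std::cout << NextConstName("tvmgen_default_dnnl_main_0", &const_names) << "\n";  // ..._const_0
  std::cout << NextConstName("tvmgen_default_dnnl_main_0", &const_names) << "\n";  // ..._const_1
  return 0;
}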
7 changes: 4 additions & 3 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -47,7 +47,9 @@ namespace {
 Target GetCutlassTarget() {
   Target target = Target::Current(/*allow_not_defined=*/true);
   if (!target.defined() || target->kind->name != "cutlass") {
-    // Use the default CUTLASS compilation options.
+    // Use the default CUTLASS compilation options if no specific "cutlass" target was given
+    // in the overall targets list. In that case target_hooks.cc will invoke the custom pass
+    // without pushing any target instance onto the implicit target stack.
     target = Target("cutlass");
   }
   return target;
@@ -912,8 +914,7 @@ transform::Pass CompileForCutlassImpl() {
     Target target = GetCutlassTarget();
     runtime::Module runtime_mod = (*pf)(mod, target);
     Array<runtime::Module> external_mods =
-        mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
-            .value();
+        mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
     external_mods.push_back(runtime_mod);
     return WithAttr(mod, tvm::attr::kExternalMods, external_mods);
   };
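Note: the second hunk above shows the IRModule-at-a-time external codegen flow end to end: read
kExternalMods with an empty default, append the freshly built runtime module, and re-attach the
attribute. Generalized as a hypothetical helper usable by any such pass:

#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>

namespace tvm {

// Append one externally generated runtime::Module to the module's
// kExternalMods attribute, returning the updated IRModule.
IRModule AttachExternalModule(IRModule mod, runtime::Module runtime_mod) {
  Array<runtime::Module> external_mods =
      mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
  external_mods.push_back(runtime_mod);
  return WithAttr(std::move(mod), tvm::attr::kExternalMods, external_mods);
}

}  // namespace tvm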
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/dnnl/codegen.cc
@@ -586,7 +586,7 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) {
   serializer.serialize();
   std::string graph_json = serializer.GetJSON();
 
-  // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
+  // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
   // a callback which calls backend::UpdateConstants to capture the map before the function
   // 'disappears' into lowered form, on the assumption the visit order and thus constant
   // names match those generated by the JSONSerializer.
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/tensorrt/codegen.cc
@@ -319,7 +319,7 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
   std::string graph_json = serializer.GetJSON();
   VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
 
-  // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
+  // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
   // a callback which calls backend::UpdateConstants to capture the map before the function
   // 'disappears' into lowered form, on the assumption the visit order and thus constant
   // names match those generated by the JSONSerializer.
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/verilator/codegen.cc
@@ -112,7 +112,7 @@ runtime::Module VerilatorBackend(const ObjectRef& ref) {
   serializer.serialize();
   std::string graph_json = serializer.GetJSON();
 
-  // Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
+  // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
   // a callback which calls backend::UpdateConstants to capture the map before the function
   // 'disappears' into lowered form, on the assumption the visit order and thus constant
   // names match those generated by the JSONSerializer.
14 changes: 5 additions & 9 deletions src/relay/backend/graph_executor_codegen.cc
@@ -262,18 +262,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     // Collect any runtime modules generated by external codegen.
     ret.external_mods =
-        lowered_mod
-            ->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
-            .value();
+        lowered_mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
 
     // Collect any constants extracted by external codegen.
     ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
-    Map<String, runtime::NDArray> const_name_to_ndarray =
-        lowered_mod
-            ->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToNDArray,
-                                                     Map<String, runtime::NDArray>())
-            .value();
-    for (const auto& kv : const_name_to_ndarray) {
+    Map<String, runtime::NDArray> const_name_to_constant =
+        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
+            .value_or(Map<String, runtime::NDArray>());
+    for (const auto& kv : const_name_to_constant) {
       ICHECK(ret.params.emplace(kv.first, kv.second).second);
     }
 
3 changes: 1 addition & 2 deletions src/relay/backend/te_compiler.cc
@@ -1196,8 +1196,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr
   // annotate the module with the resulting runtime modules.
   // TODO(mbs): runtime modules should be first class rather than attributes.
   Array<runtime::Module> external_mods =
-      module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
-          .value();
+      module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
   Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions();
   VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size()
           << " new external modules";
16 changes: 6 additions & 10 deletions src/relay/backend/vm/compiler.cc
@@ -1170,24 +1170,20 @@ void VMCompiler::Codegen() {
   // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time
   // and IRModule-at-a-time).
   Array<runtime::Module> external_mods =
-      context_.module
-          ->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
-          .value();
+      context_.module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
 
   // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes).
-  Map<String, runtime::NDArray> const_name_to_ndarray =
-      context_.module
-          ->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToNDArray,
-                                                   Map<String, runtime::NDArray>())
-          .value();
+  Map<String, runtime::NDArray> const_name_to_constant =
+      context_.module->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
+          .value_or({});
 
   VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, "
-          << external_mods.size() << " external runtime modules, " << const_name_to_ndarray.size()
+          << external_mods.size() << " external runtime modules, " << const_name_to_constant.size()
           << " external constants, and " << params_.size() << " local constants";
 
   // Any constant bindings must be merged into the overall 'params' map we've directly accumulated
   // via the TECompiler callback.
-  for (const auto& kv : const_name_to_ndarray) {
+  for (const auto& kv : const_name_to_constant) {
     ICHECK_EQ(params_.count(kv.first), 0);
     params_.emplace(kv.first, kv.second);
   }
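Note: the final loop above enforces that constants captured by IRModule-at-a-time external
codegen never collide with the params the TECompiler callback accumulated directly. The same
invariant as a standalone helper; hypothetical, mirroring VMCompiler::Codegen:

#include <string>
#include <unordered_map>

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>

namespace tvm {

// Fold external-codegen constants into the locally accumulated params map.
// A name collision indicates a broken compiler invariant, not user error.
void MergeExternalConstants(const Map<String, runtime::NDArray>& const_name_to_constant,
                            std::unordered_map<std::string, runtime::NDArray>* params) {
  for (const auto& kv : const_name_to_constant) {
    ICHECK_EQ(params->count(kv.first), 0) << "duplicate constant name: " << kv.first;
    params->emplace(kv.first, kv.second);
  }
}

}  // namespace tvm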
6 changes: 3 additions & 3 deletions src/tir/transforms/extract_constants.cc
@@ -80,14 +80,14 @@ tvm::transform::Pass ExtractPrimFuncConstants() {
       }
       auto* attrs = m->attrs.CopyOnWrite();
       ConstArrayType constant_array_ =
-          (attrs->dict.count(tvm::attr::kConstantsArray))
-              ? Downcast<ConstArrayType>(attrs->dict[tvm::attr::kConstantsArray])
+          (attrs->dict.count(tvm::attr::kConstants))
+              ? Downcast<ConstArrayType>(attrs->dict[tvm::attr::kConstants])
               : ConstArrayType();
       Applicator a = Applicator();
       func->body = a.Apply(func->body, constant_array_);
       const ConstArrayType constant_list = a.constant_array_;
       if (constant_list.size()) {
-        attrs->dict.Set(tvm::attr::kConstantsArray, constant_list);
+        attrs->dict.Set(tvm::attr::kConstants, constant_list);
       }
       return GetRef<PrimFunc>(func);
     };
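Note: on the consumer side, the kConstants comment in include/tvm/ir/module.h says each
tir::AllocateConst records an index into this module-level array. A hedged sketch of that
lookup; the irmod_storage_idx field name is taken from tir::AllocateConstNode in TVM of this
era and is an assumption here, not something this diff touches:

#include <tvm/ir/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/tir/stmt.h>

namespace tvm {

// Resolve the NDArray behind an extracted constant via the module-level
// "constants" array populated by ExtractPrimFuncConstants.
runtime::NDArray LookupExtractedConstant(const IRModule& mod,
                                         const tir::AllocateConstNode* alloc) {
  Array<runtime::NDArray> constants =
      mod->GetAttr<Array<runtime::NDArray>>(tvm::attr::kConstants).value_or({});
  ICHECK(alloc->irmod_storage_idx.defined()) << "constant was not extracted";
  return constants[alloc->irmod_storage_idx.value()->value];
}

}  // namespace tvm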
