[BYOC] Handle constants in IRModule-at-a-time external codegen
I tried to do to the TensorRT integration what #11631 did to the CUTLASS integration, viz:
 - Make sure all compilation options are passed in Target instances. This helps Collage.
 - Use a custom pass invoked via RelayToTIRTargetHooks instead of the relay.ext.$toolchain mechanism.
   This helps us decouple external codegen from lowering.

This PR collects the prep for that change:
 - TensorRT uses the JSONSerializer visitor to encode each partition function. Previously, when the
   visitor encountered a Constant, it simply generated and recorded a name for the constant. Then,
   completely separately, and via a callback in TECompiler, the function was visited again in the
   same order and with the same name-generation convention by a ConstantUpdater to actually collect the
   bindings, which were then encoded into a ConstLoaderModule to be made available at runtime.

   However, if all TensorRT compilation is to be done by a stand-alone pass, there's no TECompiler callback
   hackery available. So I've added a "const_name_to_ndarray" attribute to the IRModule, of type
   Map<String, runtime::NDArray>, so that named constants can be accumulated throughout compilation by
   any pass which needs to do so (see the sketch after this list). The Graph, AOT and VM executors are all
   updated to merge those constants into the final runtime artifact.

   (Compare with "Constants", the equivalent attribute for extracting TIR AllocateConsts.)

 - The TensorRT tests use the create_executor interface, but it wasn't quite ready for the
   new, more general form of passing a list of targets.

 - I want TensorRT compilation to work out of the box, without the need for any special targets, when
   all the default options should apply. So I've gone back and made the CUTLASS integration follow the
   same convention.

 - TensorRT actually needs to 'undo' partitioning in some situations. Add an InlineCompilerFunctions
   pass to make that robust. In particular, it must undo both the 'partitioning' (i.e. separating out
   the "Compiler" function) and any 'compositing' (i.e. separating out small sub-graphs as
   "Composite" functions).
mbs-octoml committed Jun 17, 2022
1 parent dffc310 commit 5de3adf
Showing 21 changed files with 428 additions and 148 deletions.
28 changes: 26 additions & 2 deletions include/tvm/ir/module.h
@@ -479,8 +479,10 @@ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,

namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Executor targetted by the module
* \brief Executor targeted by the module
*
* Type: Executor
*
@@ -507,10 +509,32 @@ constexpr const char* kRuntime = "runtime";
constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";

/*
* \brief Module attribute for tir constants
* \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
* node will record the index into this array. See also kConstNameToNDArray below, which is
 * the analog for Relay Functions.
*
* Type: Array<runtime::NDArray>
*/
constexpr const char* kConstantsArray = "Constants";

/*!
* \brief All the runtime::Modules accumulated during compilation by external codegen. These
* modules must be either directly linked or captured in the final compilation artifact.
*
* Type: Array<runtime::Module>
*/
constexpr const char* kExternalMods = "external_mods";

/*!
* \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
* Generally the associated runtime::Module will indicate it requires bindings for these names,
* and during module initialization these bindings will be recovered from a ConstLoaderModule.
* See also kConstantsArray above, which is the analog for PrimFuncs.
*
* Type: Map<String, runtime::NDArray>
*/
constexpr const char* kConstNameToNDArray = "const_name_to_ndarray";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_MODULE_H_
16 changes: 8 additions & 8 deletions python/tvm/relay/build_module.py
@@ -707,24 +707,24 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
-------
executor : :py:class:`~tvm.relay.backend.interpreter.Executor`
"""
raw_targets = Target.canon_multi_target(target)
if mod is None:
mod = IRModule()
if device is not None:
assert device.device_type == _nd.device(str(target), 0).device_type
assert device.device_type == raw_targets[0].kind.device_type
else:
device = _nd.device(str(target), 0)
device = _nd.device(raw_targets[0].kind.device_type, 0)

if params is not None:
mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))

if isinstance(target, str):
target = Target(target)
if kind == "debug":
return _interpreter.Interpreter(mod, device, target)
assert len(raw_targets) == 1, "The interpreter currently only supports a single target"
return _interpreter.Interpreter(mod, device, raw_targets[0])
if kind == "graph":
return GraphExecutor(mod, device, target)
return GraphExecutor(mod, device, raw_targets)
if kind == "vm":
return VMExecutor(mod, device, target)
return VMExecutor(mod, device, raw_targets)
if kind == "aot":
return AotExecutor(mod, device, target)
return AotExecutor(mod, device, raw_targets)
raise RuntimeError("unknown execution strategy: {0}".format(kind))
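
A hedged usage sketch of the more general target handling above: since the target argument is now
canonicalised via Target.canon_multi_target, a plain string, a Target instance, or a list of targets
should all be accepted. The module below is purely illustrative.

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(1.0)))

# Passing a list of targets (here just one) is now supported.
exe = relay.create_executor("graph", mod=mod, device=tvm.cpu(), target=["llvm"])
print(exe.evaluate()(np.ones((2,), dtype="float32")))
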
26 changes: 24 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.
Returns
-------
@@ -1412,11 +1412,33 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.
Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.MarkCompilerFunctionsAsExtern(compiler_filter)


def InlineCompilerFunctions(global_vars):
"""Inlines all global functions bound to a global var in global_vars.
Both the global "Compiler"-attributed function, and any "Composite" functions in its body, are
inlined.
This pass may be useful for external codegen which needs to undo partitioning based on
properties of the entire partition.
Parameters
----------
global_vars : Array[tvm.relay.GlobalVar]
The global vars of all 'Compiler' functions to inline.
Returns
-------
ret : tvm.transform.Pass
The pass.
"""
return _ffi_api.InlineCompilerFunctions(global_vars)
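
A hedged sketch of how an external codegen integration might use this pass to undo partitioning for
selected partitions. The "tensorrt" compiler name and the selection predicate are illustrative; mod is
assumed to already contain "Compiler"-attributed partition functions.

from tvm import relay

def undo_partitions(mod, compiler_name="tensorrt"):
    # Collect the global vars of every partition we have decided not to offload after all.
    to_inline = []
    for gvar in mod.get_global_vars():
        func = mod[gvar]
        if not isinstance(func, relay.Function) or func.attrs is None:
            continue
        if "Compiler" in func.attrs and func.attrs["Compiler"] == compiler_name:
            to_inline.append(gvar)
    # Inlines both the "Compiler" functions and any "Composite" functions in their bodies.
    return relay.transform.InlineCompilerFunctions(to_inline)(mod)
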
43 changes: 21 additions & 22 deletions src/relay/backend/aot_executor_codegen.cc
@@ -1149,11 +1149,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
LoweredOutput ret;
ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
for (auto param : params_) {
ret.params.emplace(std::make_pair(
param.first,
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));

// Collect any constants extracted by external codegen.
ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
Map<String, runtime::NDArray> const_name_to_ndarray =
lowered_mod
->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToNDArray,
Map<String, runtime::NDArray>())
.value();
for (const auto& kv : const_name_to_ndarray) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// Collect any constants extracted during lowering.
for (const auto& kv : params_) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
@@ -1199,9 +1209,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
lowered_mod = pack_calls(lowered_mod);
}

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";
// Collect any runtime modules generated by external codegen.
ret.external_mods = lowered_mod
->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods,
Array<tvm::runtime::Module>())
.value();

// This is the point where we separate the functions in the module by target
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
@@ -1214,8 +1226,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
<< PrettyPrint(kv.second);
}

ret.external_mods = external_modules.value();

// Extract USMP metadata to pass onto metadata sources
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
std::vector<tir::Var> pool_vars;
@@ -1282,11 +1292,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
String key = args[0];
*rv = get_param_by_name(key);
});
} else if (name == "get_param_id") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
String key = args[0];
*rv = get_param_id(key);
});
} else if (name == "get_irmodule") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); });
@@ -1328,17 +1333,11 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
runtime::NDArray get_param_by_name(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.second;
return (*it).second;
}

Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; }

int get_param_id(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.first;
}

Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
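
For clarity, the merge performed above can be sketched in Python (hedged, names illustrative):
constants recorded by external codegen via the "const_name_to_ndarray" attribute and constants
extracted during lowering end up in a single params map, and a name collision indicates a compiler bug.

def merge_params(external_consts, lowered_consts):
    # external_consts: name -> NDArray pairs from the "const_name_to_ndarray" attribute.
    # lowered_consts:  name -> NDArray pairs collected during lowering (params_ above).
    merged = dict(external_consts)
    for name, ndarray in lowered_consts.items():
        assert name not in merged, "duplicate constant name: " + name
        merged[name] = ndarray
    return merged
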
11 changes: 0 additions & 11 deletions src/relay/backend/build_module.cc
@@ -86,17 +86,6 @@ struct ExecutorCodegen {
return ret;
}

std::unordered_map<std::string, int64_t> GetParamIds() {
std::unordered_map<std::string, int64_t> ret;
auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
for (const auto& expr : names) {
// Implicit cast from runtime::String to std::string
std::string key = expr;
ret[key] = CallFunc<int64_t>("get_param_id", key);
}
return ret;
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}
9 changes: 7 additions & 2 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
ACLJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto param_names = serializer.GetParams();

// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names());
return lib;
}

8 changes: 6 additions & 2 deletions src/relay/backend/contrib/bnns/codegen.cc
@@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) {
BNNSJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
auto mod = (*pf)(func_name, graph_json, serializer.const_names());
return mod;
}

45 changes: 35 additions & 10 deletions src/relay/backend/contrib/codegen_json/codegen_json.h
@@ -33,6 +33,8 @@
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../../../runtime/contrib/json/json_node.h"
@@ -150,7 +152,8 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
* \param symbol The symbol that represents the graph being converted.
* \param expr The Relay expression to be converted to the JSON form.
*/
JSONSerializer(const std::string& symbol, const Expr& expr) : symbol_(symbol), func_(expr) {}
JSONSerializer(std::string symbol, Expr expr)
: symbol_(std::move(symbol)), func_(std::move(expr)) {}

void serialize() {
relay::Function func = Downcast<relay::Function>(func_);
@@ -162,8 +165,18 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
heads_ = VisitExpr(func->body);
}

/*!\brief Return the required params. */
Array<String> GetParams() const { return params_; }
/*!
* \brief Returns the accumulated map from constant names to the NDArray they must be bound to
 * at runtime. Also referred to as 'params' elsewhere in the code.
*/
const std::unordered_map<std::string, runtime::NDArray>& const_name_to_ndarray() const {
return const_name_to_ndarray_;
}

/*!
 * \brief Return the constant names in the order they were encountered during translation.
*/
const Array<String>& const_names() const { return const_names_; }

/*!\brief Return the generated json. */
std::string GetJSON() {
@@ -245,11 +258,14 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
return memo_[GetRef<Expr>(vn)];
}

std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* cn) {
std::string name = symbol_ + "_const_" + std::to_string(params_.size());
params_.push_back(name);
auto node = std::make_shared<JSONGraphNode>(name, "const" /* op_type_ */);
return AddNode(node, GetRef<Expr>(cn));
std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* constant_node) {
std::string name = symbol_ + "_const_" + std::to_string(const_names_.size());
VLOG(1) << "Will require parameter '" << name << "' to be available at runtime";
ICHECK_EQ(const_name_to_ndarray_.count(name), 0);
const_name_to_ndarray_.emplace(name, constant_node->data);
const_names_.push_back(name);
auto node = std::make_shared<JSONGraphNode>(name, /*op_type=*/"const");
return AddNode(node, GetRef<Expr>(constant_node));
}

std::vector<JSONGraphNodeEntry> VisitExpr_(const TupleNode* tn) {
@@ -340,8 +356,17 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
std::vector<JSONGraphObjectPtr> nodes_;
/*! \brief Output of the JSON graph. */
std::vector<JSONGraphNodeEntry> heads_;
/*! \brief The list of required constants. */
Array<String> params_;
/*!
* \brief A map from constant names to NDArrays for each Constant encountered during
* translation to JSON. The JSON will record only the constant name. The actual NDArray must
* be made available at runtime from a ConstLoaderModule.
*/
std::unordered_map<std::string, runtime::NDArray> const_name_to_ndarray_;
/*!
 * \brief The domain of the above map, but in the order the constants were encountered during
* translation.
*/
Array<String> const_names_;
};

} // namespace contrib
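
The constant-naming convention in VisitExpr_(const ConstantNode*) above is what the JSON runtimes rely
on to recover their bindings from a ConstLoaderModule. A small sketch of the expected names for a
hypothetical partition symbol:

def expected_const_names(symbol, num_constants):
    # Constants are named in visit order as "<symbol>_const_<i>", mirroring
    # JSONSerializer::VisitExpr_(const ConstantNode*).
    return ["%s_const_%d" % (symbol, i) for i in range(num_constants)]

# For example, a hypothetical partition with two constants:
print(expected_const_names("tvmgen_default_example_codegen_main_0", 2))
# ['tvmgen_default_example_codegen_main_0_const_0', 'tvmgen_default_example_codegen_main_0_const_1']
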
21 changes: 15 additions & 6 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -43,6 +43,16 @@ namespace cutlass {

namespace {

/*! \brief Return the "cutlass" Target instance to use to guide compilation. */
Target GetCutlassTarget() {
Target target = Target::Current(/*allow_not_defined=*/true);
if (!target.defined() || target->kind->name != "cutlass") {
// Use the default CUTLASS compilation options.
target = Target("cutlass");
}
return target;
}

using Str2StrMap = std::unordered_map<std::string, std::string>;

static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"},
@@ -899,14 +909,13 @@ transform::Pass CompileForCutlassImpl() {
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
Optional<Target> opt_cutlass_target = Target::Current();
ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available";
ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass");
runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value());
Target target = GetCutlassTarget();
runtime::Module runtime_mod = (*pf)(mod, target);
Array<runtime::Module> external_mods =
mod->GetAttr<Array<runtime::Module>>("external_mods", Array<runtime::Module>()).value();
mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
.value();
external_mods.push_back(runtime_mod);
return WithAttr(mod, "external_mods", external_mods);
return WithAttr(mod, tvm::attr::kExternalMods, external_mods);
};
return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {});
}
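
A hedged sketch of what the GetCutlassTarget fallback above means from the user's side: an explicit
"cutlass" target can still carry custom compilation options, but when none is in scope the pass
constructs Target("cutlass") with all defaults, so external codegen works out of the box. The attribute
in the commented line is illustrative; the real options are defined by the "cutlass" TargetKind
registration.

import tvm

# Explicit target: compilation options ride on the Target instance, which helps Collage.
# cutlass = tvm.target.Target({"kind": "cutlass", "sm": 80})   # illustrative option

# Default target: equivalent to what CompileForCutlass falls back to internally.
cutlass = tvm.target.Target("cutlass")
print(cutlass.kind.name)   # "cutlass"
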
8 changes: 6 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -585,11 +585,15 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) {
DNNLJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
auto mod = (*pf)(func_name, graph_json, serializer.const_names());
return mod;
#else
DNNLModuleCodegen dnnl;