Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Prepare for switching VM to LowerTEPass. #9550

Merged
merged 1 commit into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
#ifndef TVM_PARSER_PARSER_H_
#define TVM_PARSER_PARSER_H_
/*!
* \file parser.h
* \file include/tvm/parser/parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
transform::Pass AnnotateSpans();

} // namespace parser
} // namespace tvm

Expand Down
11 changes: 9 additions & 2 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class Executable : public ModuleNode {
*/
std::string GetVirtualDevices() const;

/*!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what exactly the return value is here given its just a std::string.

Copy link
Contributor Author

@mbs-octoml mbs-octoml Nov 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My I please do that in the sequel so we can merge and I take 9483 out of draft? These are just human readable pretty printed descriptions so I can decipher the instructions. Thx.

* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
*/
std::string GetPrimitives() const;

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
Expand Down Expand Up @@ -201,9 +208,9 @@ class Executable : public ModuleNode {
int host_device_index = -1;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
/*! \brief A mapping from the packed function's global name (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
Expand Down
22 changes: 15 additions & 7 deletions include/tvm/target/compilation_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {

/*!
* \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to
* compile a Relay module. All centralizes any setup and validation logic needed to transition
* compile a Relay module. Centralizes any setup and validation logic needed to transition
* from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly
* (eg a a list of \p Targets) to the configuration.
*
Expand All @@ -49,9 +49,12 @@ namespace tvm {
class CompilationConfigNode : public Object {
public:
/*!
* \brief The legacy targets map, mapping device type to \p Targets. Does not include any
* entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType,
* though we want to get rid of that limitation.
* \brief The legacy targets map, mapping device type to the corresponding \p Target to use
* when compiling primitive functions. Does not include an entry for the host target, however
* each \p Target in this map will have it's \p host field set to the \p host_target.
*
* Currently we require at most one \p Target per \p DLDeviceType, though we want to get rid of
* that limitation.
*
* CAUTION: Since keys are \p Integers they are compared by object equality not integer
* value.
Expand All @@ -63,13 +66,18 @@ class CompilationConfigNode : public Object {
/*!
* \brief The host target. Used for 'scalar' data and code (such as shapes and shape
* functions) and residual Relay expressions and data (such as conditionals and ADTs).
*
* Note that it is possible for a \p Target used for primitive operations to be structurally
* equal to the host \p Target (up to the \p host field.) However the \p Target objects will
* be distinct, and can be used as keys within a \p Map without collision.
*/
Target host_target;

/*!
* \brief Vector of all available targets for primitive operators. May contain a \p Target
* for the same device type as for the \p host_target, however the \p host_target should
* be preferred for all host computations and data.
* \brief Vector of all available \p Targets for compiling primitive operators. May contain
* a \p Target for the same device type as for the \p host_target, however the \p host_target
* should be used for all host computations and data. Each \p Target will have \p host_target
* as its host.
*/
Array<Target> primitive_targets;

Expand Down
6 changes: 6 additions & 0 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ class SEScope : public ObjectRef {
return SEScope(device.device_type, device.device_id, std::move(target));
}

/*! \brief Returns the \p SEScope for \p target. */
static SEScope ForTarget(Target target) {
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
std::move(target));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, mod):
self._get_bytecode = self.mod["get_bytecode"]
self._get_constants = self.mod["get_constants"]
self._get_virtual_devices = self.mod["get_virtual_devices"]
self._get_primitives = self.mod["get_primitives"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
Expand Down Expand Up @@ -257,6 +258,12 @@ def virtual_devices(self):
"""Returns a human-readable description of all the (virtual) devices in the executable."""
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

@property
def globals(self):
"""Get the globals used by the Relay VM executable.
Expand Down
13 changes: 8 additions & 5 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
ICHECK_EQ((*it).second, var);
} else {
ICHECK(global_var_map_.count(var->name_hint) == 0)
<< "Duplicate global function name " << var->name_hint;
<< "Duplicate global function name " << PrettyPrint(var);
}

global_var_map_.Set(var->name_hint, var);
Expand Down Expand Up @@ -243,7 +243,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData&
if (!update) {
// set global type var map
ICHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->name_hint;
<< "Duplicate global type definition name " << PrettyPrint(var);
}
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
Expand All @@ -266,7 +266,7 @@ void IRModuleNode::Remove(const GlobalVar& var) {

BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand All @@ -277,7 +277,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const {

TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand Down Expand Up @@ -306,6 +306,10 @@ String IRModuleNode::GetUniqueName(const String& name) {
}
}

/*!
* \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
* ('one') side above the rhs ('two').
*/
struct Renamer : relay::ExprMutator, TypeMutator {
Map<String, GlobalVar> defs;
Map<String, GlobalTypeVar> types;
Expand Down Expand Up @@ -411,7 +415,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
std::string file_contents{std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>()};
Expand Down
27 changes: 17 additions & 10 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(9) << "ParseModule";
VLOG_CONTEXT << "ParseModule";
VLOG(9) << "parsing and type-checking " << file_name;
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand Down Expand Up @@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
return ParseExpr(file_name, file_content);
});

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
});
/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
Pass AnnotateSpans() {
auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
};
return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

} // namespace parser
} // namespace tvm
12 changes: 11 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
20 changes: 18 additions & 2 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,29 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
if (kv.second.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint << " = ";
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
}
doc << Doc::NewLine();
}
#if TVM_LOG_DEBUG
// attributes
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif
return doc;
}

Expand Down
33 changes: 27 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
Expand Down Expand Up @@ -516,10 +516,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
call_lowered_props = GetCallLoweredProps(call_node);
ICHECK(call_lowered_props.lowered_func.defined())
<< "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
Expand Down Expand Up @@ -717,6 +718,14 @@ class AOTExecutorCodegen : public MixedModeVisitor {
: mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {}

LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
VLOG_CONTEXT << "AOT";
for (const auto& kv : targets_) {
VLOG(1) << "target: " << kv.second->ToDebugString();
}
if (target_host_.defined()) {
VLOG(1) << "target host: " << target_host_->ToDebugString();
}

Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
Integer workspace_byte_alignment =
Expand Down Expand Up @@ -793,10 +802,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}

// Build the TIR IRModule for the AOT function
// Build the TIR IRModule for the main AOT function
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs);
VLOG(1) << "main module:" << std::endl << PrettyPrint(mod_run);

// Apply storage rewrite pass to the runner function to do memory planning
auto storage_rewrite = tir::transform::StorageRewrite();
Expand Down Expand Up @@ -827,12 +837,23 @@ class AOTExecutorCodegen : public MixedModeVisitor {
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
VLOG(1) << "per-target modules:";
for (const auto& kv : ret.lowered_funcs) {
VLOG(1) << "target:" << std::endl
<< kv.first->ToDebugString() << std::endl
<< "maps to:" << std::endl
<< PrettyPrint(kv.second);
}

ret.external_mods = external_modules.value();

if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
VLOG(1) << "merging main into existing module for host target";
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
VLOG(1) << "adding main into new module for host target";
ret.lowered_funcs.Set(target_host_, mod_run);
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
std::vector<GraphNodeRef> inputs;
std::string func_name;

if (call->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

func_name = call_lowered_props.lowered_func->name_hint;

Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
}

ObjectRef VisitExpr_(const CallNode* call_node) final {
if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function.
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Special case: Call a lowered TIR function.

// Evaluate only function args
std::vector<ObjectRef> args;
Expand Down
Loading