Skip to content

Commit

Permalink
** Use LowerTEPass in VM **
Browse files Browse the repository at this point in the history
We replace use of the TECompiler::{Lower,LowerShapeFunc} methods from the VM's
compiler.cc with LowerTEPass. This clears the way for performing post-lowering
IRModule->IRModule transformations which combine Relay and TIR analysis. In particular,
it will allow us to use the PlanDevices pass to propagate memory scope constraints
across PrimFuncs.

We run LowerTEPass fairly early in the pipeline, which required quite a few passes
to become 'post-lowering friendly'. In particular, ManifestAlloc is now run after
rather than before lowering, and so must now work in a mixed Function/PrimFunc world.

The "vm.shape_func" operator has been removed since a) lowering has already generated
the necessary dynamic shape function, and b) the call to that function can be
represented by an 'ordinary' vm.invoke_tvm_op call.

We worked our way through the following glitches:
 - Lowering was choosing definitional GlobalVars which were not pointer-equal to the
   referential GlobalVars left behind in the rewritten Calls. We fixed that in
   te_compiler.cc, though better would be to push GlobalVars deeper into the
   lowering machinery.
 - device_copy was rewritten to a call to @__copy without any definition. We retain
   it as if it were an 'external'.
 - Calls to already-compiled BYOC functions were indistinguishable from calls
   to (non-primitive) Relay functions. We move them into the call_lowered calling
   convention, and leave behind a Function tagged with "ExternalSymbol". Better would
   be a first-class representatn for externals in the IRModule but one step at a time.
 - Functions with dynamic shapes tagged for BYOC compilation were not tracking their
   connection to their dynamic shape function. We now use exactly the same attributes
   as for non-BYOC primitives.
  • Loading branch information
mbs-octoml committed Nov 19, 2021
1 parent 6832309 commit 47cae68
Show file tree
Hide file tree
Showing 55 changed files with 1,567 additions and 1,202 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ Rust language support in TVM includes two parts. 1. The frontend wraps the curre
* [Relay] Fix memory leak in the interpreter (#4155)
* [rpc] use callback func to do send & recv (#4147)
* Add `lift_if_then_else` pass to improve loop partitioning (#3865)
* Decrease the complexity of CalcDep from exponential to linear (#4053)
* Decrease the complexity of UsageVisitor from exponential to linear (#4053)
* [IR] Make iterators compatible with constructors of STL containers (#3624)
* [Relay][Pass] Avoid FoldConstant folding some ops (#4245)
* [Relay][Prelude] More dtypes support in `tensor_t` (#4233)
Expand Down
10 changes: 9 additions & 1 deletion include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
#ifndef TVM_PARSER_PARSER_H_
#define TVM_PARSER_PARSER_H_
/*!
* \file parser.h
* \file include/tvm/parser/parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
transform::Pass AnnotateSpans();

} // namespace parser
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
/*!
* \brief Remove the continuation argument of a CPS function.
*
* Note that this only transform the type back into un-CPS form
* Note that this only transform the type back into un-CPS formA
* when there is no higher order input/output.
*
* \param f the function.
Expand Down
51 changes: 29 additions & 22 deletions include/tvm/runtime/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,22 +570,25 @@ TVM_CHECK_FUNC(_NE, !=)
#define TVM_CHECK_BINARY_OP(name, op, x, y) \
if (auto __tvm__log__err = ::tvm::runtime::detail::LogCheck##name(x, y)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": "
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() \
<< ": Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": "

#define CHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< "Check failed: (" #x << ") is false: "
#define CHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() << ": Check failed: (" #x \
<< ") is false: "

#define CHECK_LT(x, y) TVM_CHECK_BINARY_OP(_LT, <, x, y)
#define CHECK_GT(x, y) TVM_CHECK_BINARY_OP(_GT, >, x, y)
#define CHECK_LE(x, y) TVM_CHECK_BINARY_OP(_LE, <=, x, y)
#define CHECK_GE(x, y) TVM_CHECK_BINARY_OP(_GE, >=, x, y)
#define CHECK_EQ(x, y) TVM_CHECK_BINARY_OP(_EQ, ==, x, y)
#define CHECK_NE(x, y) TVM_CHECK_BINARY_OP(_NE, !=, x, y)
#define CHECK_NOTNULL(x) \
((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< "Check not null: " #x << ' ', \
#define CHECK_NOTNULL(x) \
((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() \
<< ": Check not null: " #x << ' ', \
(x) : (x)) // NOLINT(*)

#define LOG_IF(severity, condition) \
Expand Down Expand Up @@ -664,28 +667,32 @@ TVM_CHECK_FUNC(_NE, !=)

#define TVM_ICHECK_INDENT " "

#define ICHECK_BINARY_OP(name, op, x, y) \
if (auto __tvm__log__err = ::tvm::runtime::detail::LogCheck##name(x, y)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << std::endl \
<< TVM_ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": "
#define ICHECK_BINARY_OP(name, op, x, y) \
if (auto __tvm__log__err = ::tvm::runtime::detail::LogCheck##name(x, y)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << std::endl \
<< TVM_ICHECK_INDENT << ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() \
<< ": Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": "

#define ICHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << TVM_ICHECK_INDENT \
<< "Check failed: (" #x << ") is false: "
#define ICHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE << TVM_ICHECK_INDENT \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() << ": Check failed: (" #x \
<< ") is false: "

#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y)
#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y)
#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y)
#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y)
#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y)
#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y)
#define ICHECK_NOTNULL(x) \
((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE \
<< TVM_ICHECK_INDENT << "Check not null: " #x << ' ', \
#define ICHECK_NOTNULL(x) \
((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< ::tvm::runtime::detail::kTVM_INTERNAL_ERROR_MESSAGE \
<< TVM_ICHECK_INDENT \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() \
<< ": Check not null: " #x << ' ', \
(x) : (x)) // NOLINT(*)

} // namespace runtime
Expand Down
11 changes: 9 additions & 2 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class Executable : public ModuleNode {
*/
std::string GetVirtualDevices() const;

/*!
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
*/
std::string GetPrimitives() const;

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
Expand Down Expand Up @@ -201,9 +208,9 @@ class Executable : public ModuleNode {
int host_device_index = -1;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
/*! \brief A mapping from the packed function's global name (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ class SEScope : public ObjectRef {
return SEScope(device.device_type, device.device_id, std::move(target));
}

/*! \brief Returns the \p SEScope for \p target. */
static SEScope ForTarget(Target target) {
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
std::move(target));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, mod):
self._get_bytecode = self.mod["get_bytecode"]
self._get_constants = self.mod["get_constants"]
self._get_virtual_devices = self.mod["get_virtual_devices"]
self._get_primitives = self.mod["get_primitives"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
Expand Down Expand Up @@ -257,6 +258,12 @@ def virtual_devices(self):
"""Returns a human-readable description of all the (virtual) devices in the executable."""
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

@property
def globals(self):
"""Get the globals used by the Relay VM executable.
Expand Down
13 changes: 8 additions & 5 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
ICHECK_EQ((*it).second, var);
} else {
ICHECK(global_var_map_.count(var->name_hint) == 0)
<< "Duplicate global function name " << var->name_hint;
<< "Duplicate global function name " << PrettyPrint(var);
}

global_var_map_.Set(var->name_hint, var);
Expand Down Expand Up @@ -242,7 +242,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData&
if (!update) {
// set global type var map
ICHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->name_hint;
<< "Duplicate global type definition name " << PrettyPrint(var);
}
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
Expand All @@ -265,7 +265,7 @@ void IRModuleNode::Remove(const GlobalVar& var) {

BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand All @@ -276,7 +276,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const {

TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand Down Expand Up @@ -305,6 +305,10 @@ String IRModuleNode::GetUniqueName(const String& name) {
}
}

/*!
* \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
* ('one') side above the rhs ('two').
*/
struct Renamer : relay::ExprMutator, TypeMutator {
Map<String, GlobalVar> defs;
Map<String, GlobalTypeVar> types;
Expand Down Expand Up @@ -410,7 +414,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
std::string file_contents{std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>()};
Expand Down
27 changes: 17 additions & 10 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(9) << "ParseModule";
VLOG_CONTEXT << "ParseModule";
VLOG(9) << "parsing and type-checking " << file_name;
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand Down Expand Up @@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
return ParseExpr(file_name, file_content);
});

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
});
/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
Pass AnnotateSpans() {
auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
};
return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

} // namespace parser
} // namespace tvm
12 changes: 11 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
10 changes: 8 additions & 2 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
if (kv.second.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint << " = ";
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
}
doc << Doc::NewLine();
}
Expand Down
26 changes: 17 additions & 9 deletions src/relay/analysis/call_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "call_graph.h"

#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/object.h>

Expand All @@ -33,6 +34,8 @@
#include <unordered_set>
#include <vector>

#include "../op/call/call.h"

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -64,9 +67,19 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
// post-order visitor will visit each AST node of the current function to
// figure out the dependencies between functions.
PostOrderVisit(func, [&](const Expr& expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(gvn);
cg_node->AddCalledGlobal(LookupGlobalVar(callee));
// TODO(mbs): Cleanup shapes functions.
if (const auto* call_node = expr.as<CallNode>()) {
CallLoweredProps props = GetCallLoweredProps(call_node);
if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {
// We are implicitly calling the shape function *in addition to* the call target.
CallGraphEntry* callee_cg_node =
LookupGlobalVar(Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var"]));
cg_node->AddCalledGlobal(callee_cg_node);
}
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
auto callee = GetRef<GlobalVar>(global_var_node);
CallGraphEntry* callee_cg_node = LookupGlobalVar(callee);
cg_node->AddCalledGlobal(callee_cg_node);
}
});
}
Expand All @@ -88,21 +101,16 @@ CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
ICHECK(module->ContainGlobalVar(var->name_hint))
<< "GlobalVar " << var->name_hint << " not found in the current ir module";
return module->Lookup(var);
return module->Lookup(var->name_hint);
}

// Query the existence of a GlobalVar in the call graph. It creates an entry if
// there is no such node available.
CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
ICHECK(gv.defined());

// This inserts an element to the call graph if it is not there yet.
auto& call_graph_node = call_graph_[gv];
if (call_graph_node) return call_graph_node.get();

ICHECK(module->ContainGlobalVar(gv->name_hint))
<< "GlobalVar " << gv->name_hint << " not found in the current ir module";

// Create the node for the inserted entry.
call_graph_node = std::unique_ptr<CallGraphEntry>(new CallGraphEntry(gv));
return call_graph_node.get();
Expand Down
4 changes: 2 additions & 2 deletions src/relay/analysis/call_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ class CallGraphNode : public Object {
GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false);

/*!
* \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
* the GlobalVar if it doesn't exist.
* \brief Returns the \p CallGraphEntry for the global function bound to \p gv. Creates an entry
* if one does not already exist.
*
* \param gv The GlobalVar for query.
*
Expand Down
Loading

0 comments on commit 47cae68

Please sign in to comment.