Skip to content

Commit

Permalink
Prepare for switching VM to LowerTEPass.
Browse files Browse the repository at this point in the history
This is a grab bag of fallout changes from switching the VM to use LoweTEPass
which can be ealy split out of the main apache#9483 PR.

- AnnotateSpans can be used from C++ (though, unfortunately, it didn't help
  me with debugging since spans are universally dropped in most passes).
- Can get a human readable dump of the VM's PackedFunc names and indexes for
  debugging.
- If TVM_LOG_DEBUG defined then include types and ids of GlobalVars. I had
  a lot of difficulty tracking down where duplicate GlobalVars for the same
  name_hint were getting created and propagated.
- GetCallLoweredProps follows same API as GetDeviceCopy and GetOnDevice
  where will return 'null' properties if call/expr is not of call_lowered
  form. Mildly more convenient, though switching all the above to ICHECK
  and push 'if (op == the relevant op)' into all use sites would also be just
  fine.
- Misc VLOG improvements made while tracking down issues in apache#9483.
- Don't attach host targets to the CompilationConfig's 'primitive_targets' array.
  Since Targets and SEScopes are compared by pointer equality, and the same Target
  with and without a host are distinct objects, this was causing unnecessary copies
  in code which is already dealing with the explicit host_target or host_se_scope
  anyway. I've left the hosts in the legacy_target_map. (The sooner we sort out
  multi-target compilation and hosts the better!)
  • Loading branch information
mbs-octoml committed Nov 23, 2021
1 parent 34ea319 commit 50898e1
Show file tree
Hide file tree
Showing 25 changed files with 261 additions and 109 deletions.
10 changes: 9 additions & 1 deletion include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
#ifndef TVM_PARSER_PARSER_H_
#define TVM_PARSER_PARSER_H_
/*!
* \file parser.h
* \file include/tvm/parser/parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
transform::Pass AnnotateSpans();

} // namespace parser
} // namespace tvm

Expand Down
11 changes: 9 additions & 2 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class Executable : public ModuleNode {
*/
std::string GetVirtualDevices() const;

/*!
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
*/
std::string GetPrimitives() const;

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
Expand Down Expand Up @@ -201,9 +208,9 @@ class Executable : public ModuleNode {
int host_device_index = -1;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
/*! \brief A mapping from the packed function's global name (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ class SEScope : public ObjectRef {
return SEScope(device.device_type, device.device_id, std::move(target));
}

/*! \brief Returns the \p SEScope for \p target. */
static SEScope ForTarget(Target target) {
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
std::move(target));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, mod):
self._get_bytecode = self.mod["get_bytecode"]
self._get_constants = self.mod["get_constants"]
self._get_virtual_devices = self.mod["get_virtual_devices"]
self._get_primitives = self.mod["get_primitives"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
Expand Down Expand Up @@ -257,6 +258,12 @@ def virtual_devices(self):
"""Returns a human-readable description of all the (virtual) devices in the executable."""
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

@property
def globals(self):
"""Get the globals used by the Relay VM executable.
Expand Down
13 changes: 8 additions & 5 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
ICHECK_EQ((*it).second, var);
} else {
ICHECK(global_var_map_.count(var->name_hint) == 0)
<< "Duplicate global function name " << var->name_hint;
<< "Duplicate global function name " << PrettyPrint(var);
}

global_var_map_.Set(var->name_hint, var);
Expand Down Expand Up @@ -243,7 +243,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData&
if (!update) {
// set global type var map
ICHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->name_hint;
<< "Duplicate global type definition name " << PrettyPrint(var);
}
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
Expand All @@ -266,7 +266,7 @@ void IRModuleNode::Remove(const GlobalVar& var) {

BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand All @@ -277,7 +277,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const {

TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand Down Expand Up @@ -306,6 +306,10 @@ String IRModuleNode::GetUniqueName(const String& name) {
}
}

/*!
* \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
* ('one') side above the rhs ('two').
*/
struct Renamer : relay::ExprMutator, TypeMutator {
Map<String, GlobalVar> defs;
Map<String, GlobalTypeVar> types;
Expand Down Expand Up @@ -411,7 +415,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
std::string file_contents{std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>()};
Expand Down
27 changes: 17 additions & 10 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(9) << "ParseModule";
VLOG_CONTEXT << "ParseModule";
VLOG(9) << "parsing and type-checking " << file_name;
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand Down Expand Up @@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
return ParseExpr(file_name, file_content);
});

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
});
/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
Pass AnnotateSpans() {
auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
};
return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

} // namespace parser
} // namespace tvm
12 changes: 11 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
10 changes: 8 additions & 2 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
if (kv.second.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint << " = ";
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
}
doc << Doc::NewLine();
}
Expand Down
11 changes: 6 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
Expand Down Expand Up @@ -516,10 +516,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
call_lowered_props = GetCallLoweredProps(call_node);
ICHECK(call_lowered_props.lowered_func.defined())
<< "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
std::vector<GraphNodeRef> inputs;
std::string func_name;

if (call->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

func_name = call_lowered_props.lowered_func->name_hint;

Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
}

ObjectRef VisitExpr_(const CallNode* call_node) final {
if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function.
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Special case: Call a lowered TIR function.

// Evaluate only function args
std::vector<ObjectRef> args;
Expand Down
35 changes: 27 additions & 8 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TECompilerImpl : public TECompilerNode {
IRModule lowered_mod = lowered_func->cached_func->funcs;

// Annotate functions with their target and put them in the return module
for (auto kv : lowered_mod->functions) {
for (const auto& kv : lowered_mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;

Expand All @@ -114,6 +114,7 @@ class TECompilerImpl : public TECompilerNode {
}
}
}

// Extract lowered dynamic shape functions from the shape cache
for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
Expand All @@ -129,6 +130,7 @@ class TECompilerImpl : public TECompilerNode {
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}
}

return mod;
}

Expand Down Expand Up @@ -202,10 +204,16 @@ class TECompilerImpl : public TECompilerNode {
private:
// implement lowered func
CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
VLOG(1) << "lowering:" << std::endl
<< PrettyPrint(key->source_func) << std::endl
<< "for target:" << std::endl
<< key->target->ToDebugString();
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
if (it != cache_.end()) {
VLOG(1) << "already lowered to:" << std::endl
<< PrettyPrint(it->second->cached_func->prim_fn_var);
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
Expand All @@ -232,6 +240,11 @@ class TECompilerImpl : public TECompilerNode {
// Collect these here as it's removed in LowerExternalFunctions()
std::string codegen_name = key->source_func->GetAttr<String>(attr::kCompiler).value();
device_contexts_.Set(global_var, codegen_name);
VLOG(1) << "preparing to use external codegen '" << codegen_name
<< "' with name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "and definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}

Expand Down Expand Up @@ -266,11 +279,19 @@ class TECompilerImpl : public TECompilerNode {
cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
}
value->cached_func = cfunc;
VLOG(1) << "lowered to name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "with definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}

// implement lowered shape func
CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
VLOG(1) << "lowering dynamic shape function:" << std::endl
<< PrettyPrint(key->source_func) << std::endl
<< "for target:" << std::endl
<< key->target->ToDebugString();
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = shape_func_cache_.find(key);
Expand All @@ -295,6 +316,10 @@ class TECompilerImpl : public TECompilerNode {
});

value->cached_func = cached_func;
VLOG(1) << "lowered to name:" << std::endl
<< PrettyPrint(value->cached_func->prim_fn_var) << std::endl
<< "with definitions:" << std::endl
<< PrettyPrint(value->cached_func->funcs);
return value;
}

Expand Down Expand Up @@ -480,11 +505,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

} else {
// Non-External Relay Function
VLOG(1) << "lowering to target " << target->ToDebugString() << " for primitive:\n"
<< PrettyPrint(func);
CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);
VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";

// Collect all the lowered functions produced for this primitive function.
Map<GlobalVar, tir::PrimFunc> prim_fns;
Expand All @@ -493,7 +515,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
all_prim_fn_vars.push_back(prim_fn.first);
VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'";
}

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
Expand Down Expand Up @@ -529,8 +550,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
// on the host cpu irrespective of where the primitive runs.
// TODO(mbs): Cleanup target handling.
Target shape_target("llvm");
VLOG(1) << "lowering to target " << shape_target->ToDebugString()
<< " for dynamic shape function for primitive";
CCacheKey shape_key(func, shape_target);
CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
// Capture the shape function's global var and parameters 'states' in call
Expand Down Expand Up @@ -925,7 +944,7 @@ void UpdateFunctionMetadata(BaseFunc func,
std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes),
std::move(tir_primfuncs), std::move(relay_primfuncs));

VLOG(1) << "FunctionInfo: " << prim_fn_var.value()->name_hint << " = " << PrettyPrint(fi);
VLOG(1) << "FunctionInfo: " << PrettyPrint(prim_fn_var.value()) << " = " << PrettyPrint(fi);

// The primitive function name here corresponds to the string we will use to generate
// this Relay function at the low level.
Expand Down
Loading

0 comments on commit 50898e1

Please sign in to comment.