[checkpoint] externs treated as call_lowered so can see dyn shape function info

It compiles something, but not the same as the original.
mbs-octoml committed Nov 18, 2021
1 parent 723e18f commit 845f516
Showing 7 changed files with 160 additions and 128 deletions.
9 changes: 8 additions & 1 deletion include/tvm/runtime/vm/executable.h
@@ -144,6 +144,13 @@ class Executable : public ModuleNode {
*/
std::string GetVirtualDevices() const;

/*!
* \brief Returns a description of all the 'primitives' (i.e. PackedFuncs) in the executable.
* These correspond to either PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
*/
std::string GetPrimitives() const;

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
globals and constants, etc.
@@ -201,7 +208,7 @@ class Executable : public ModuleNode {
int host_device_index = -1;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function's global name (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
25 changes: 16 additions & 9 deletions python/tvm/runtime/vm.py
@@ -73,6 +73,7 @@ def __init__(self, mod):
self._get_bytecode = self.mod["get_bytecode"]
self._get_constants = self.mod["get_constants"]
self._get_virtual_devices = self.mod["get_virtual_devices"]
self._get_primitives = self.mod["get_primitives"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
@@ -257,6 +258,12 @@ def virtual_devices(self):
"""Returns a human-readable description of all the (virtual) devices in the executable."""
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

@property
def globals(self):
"""Get the globals used by the Relay VM executable.
@@ -522,15 +529,15 @@ def get_input_index(self, input_name, func_name="main"):
return self._get_input_index(input_name, func_name)

def benchmark(
self,
device,
*args,
func_name="main",
repeat=5,
number=5,
min_repeat_ms=None,
end_to_end=False,
**kwargs,
self,
device,
*args,
func_name="main",
repeat=5,
number=5,
min_repeat_ms=None,
end_to_end=False,
**kwargs,
):
"""Calculate runtime of a function by repeatedly calling it.
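For context, a minimal usage sketch of the new accessor (not part of this commit; assumes a TVM build that includes this patch, and uses an illustrative MLP workload from relay.testing):

import tvm
from tvm import relay
from tvm.relay import testing

# Compile a small static workload with the Relay VM (illustrative only).
mod, params = testing.mlp.get_workload(batch_size=1)
with tvm.transform.PassContext(opt_level=3):
    exe = relay.vm.compile(mod, target="llvm", params=params)

print(exe.virtual_devices)  # existing accessor
print(exe.primitive)        # new accessor, backed by the get_primitives packed func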
193 changes: 91 additions & 102 deletions src/relay/backend/te_compiler.cc
@@ -505,6 +505,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
* nullptr if none.
*/
BaseFunc ResolveToPrimitive(const Expr& expr) {
// NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order
// expressions.
if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
if (!module_->ContainGlobalVar(global_var_node->name_hint)) {
// TODO(mbs): extern function cleanup
@@ -518,8 +520,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
return GetRef<tir::PrimFunc>(prim_func_node);
} else if (const auto* var_node = expr.as<VarNode>()) {
auto itr = primitive_functions_.find(GetRef<Var>(var_node));
ICHECK(itr != primitive_functions_.end());
return itr->second;
if (itr == primitive_functions_.end()) {
// Not bound to a function.
return {};
} else {
return itr->second;
}
} else if (const auto* function_node = expr.as<FunctionNode>()) {
if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
// Not marked as primitive by FuseOps.
@@ -545,108 +551,88 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
*/
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> type_args, Span span,
Target target) {
CCacheKey key = CCacheKey(func, target);
CachedFunc cfunc = compiler_->Lower(key, module_name_);
ICHECK(cfunc.defined());

auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
if (opt_compiler.defined()) {
// BYOC flow.
CCacheKey key = CCacheKey(func, target);
CachedFunc ext_func = compiler_->Lower(key, module_name_);
ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
<< PrettyPrint(ext_func->prim_fn_var);

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
Map<GlobalVar, tir::PrimFunc> prim_fns;
relay::Function func_with_metadata = func;
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target);

// Provide a callback hook which allows one-level up code generators to
// act when we process a function.
this->process_fn_(func_with_metadata);

// TODO(mbs): Dynamic shapes?
// TODO(@mbs, electriclilies): Make extern functions explicit
return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span);
} else {
// Non-External Relay Function
CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);

// Collect all the lowered functions produced for this primitive function.
// Special handling for "device_copy": No lowered definition is returned (indeed the
// lowered_func->funcs is empty).
// TODO(mbs): "device_copy" cleanup
Map<GlobalVar, tir::PrimFunc> prim_fns;
Array<GlobalVar> all_prim_fn_vars;
for (const auto& pair : lowered_func->funcs->functions) {
CHECK(pair.second.as<tir::PrimFuncNode>())
// Add some metadata on top of the *original function* and invoke the callback so it can
// be captured.
// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
Map<GlobalVar, tir::PrimFunc> prim_fns;
Array<GlobalVar> all_prim_fn_vars;
for (const auto& pair : cfunc->funcs->functions) {
if (opt_compiler) {
// We expect just the original func but with just the ExternalSymbol attribute signalling
// the function (will be) compiled externally.
ICHECK(pair.second.as<FunctionNode>())
<< PrettyPrint(pair.first) << " must be bound to an (external) Function";
} else {
// We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive
// (and the rest in support of that via tir::Calls).
ICHECK(pair.second.as<tir::PrimFuncNode>())
<< PrettyPrint(pair.first) << " must be bound to a PrimFunc";
prim_fns.Set(pair.first, Downcast<tir::PrimFunc>(pair.second));
all_prim_fn_vars.push_back(pair.first);
}
}
Function func_with_metadata = func;
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
this->process_fn_(func_with_metadata);

auto call_lowered_attrs = make_object<CallLoweredAttrs>();

// Non-External Relay Function
// TODO(mbs): "reshape" cleanup.
if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
}

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
relay::Function func_with_metadata = func;
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target);

// Provide a callback hook which allows one-level up code generators to
// act when we process a function.
this->process_fn_(func_with_metadata);

auto call_lowered_attrs = make_object<CallLoweredAttrs>();
// TODO(mbs): "reshape" cleanup.
if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
}

// Special handling for "device_copy": Look inside the primitive Function to see if it
// is a device_copy. If so capture source and destination attributes so downstream passes
// can follow along.
// TODO(mbs): "device_copy" cleanup
DeviceCopyProps props = GetDeviceCopyProps(func->body);
if (props.body.defined()) {
call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
}
// Special handling for "device_copy": Look inside the primitive Function to see if it
// is a device_copy. If so capture source and destination attributes so downstream passes
// can follow along.
// TODO(mbs): "device_copy" cleanup
DeviceCopyProps props = GetDeviceCopyProps(func->body);
if (!opt_compiler && props.body.defined()) {
call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
}

call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);

if (IsDynamic(func->ret_type)) {
// Also lower the dynamic shape function.
// Shape function keys use the underlying primitive function as their 'function',
// but the generic 'cpu' target as the target since all shape functions run
// on the host cpu irrespective of where the primitive runs.
// TODO(mbs): Cleanup target handling.
Target shape_target = host_se_scope_->target;
CCacheKey shape_key(func, shape_target);
CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
// Capture the shape function's global var and parameters 'states' in call
// annotations so calling convention can be recovered.
// TODO(mbs): Capture all this as part of a 'call into TIR' construct once available.
// The way the shape function calling convention is derived and passed to call sites
// via the 'parameter states' could be improved.
call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
call_lowered_attrs->metadata.Set("prim_shape_fn_states",
lowered_shape_func->shape_func_param_states);
call_lowered_attrs->metadata.Set(
"prim_shape_fn_num_inputs",
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
call_lowered_attrs->metadata.Set(
"prim_shape_fn_num_outputs",
Integer(static_cast<int>(lowered_shape_func->outputs.size())));
Array<GlobalVar> all_prim_shape_fn_vars;
for (const auto& pair : lowered_shape_func->funcs->functions) {
CHECK(pair.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
all_prim_shape_fn_vars.push_back(pair.first);
}
call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);

if (IsDynamic(func->ret_type)) {
// Also lower the companion dynamic shape function.
// Shape function keys use the underlying primitive function as their 'function',
// but the generic 'cpu' target as the target since all shape functions run
// on the host cpu irrespective of where the primitive runs.
CCacheKey shape_key(func, host_se_scope_->target);
CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);

// Capture the shape function's global var and parameters 'states' in call
// annotations so calling convention can be recovered.
// TODO(mbs): Shape cleanup.
call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
call_lowered_attrs->metadata.Set("prim_shape_fn_states",
lowered_shape_func->shape_func_param_states);
call_lowered_attrs->metadata.Set(
"prim_shape_fn_num_inputs", Integer(static_cast<int>(lowered_shape_func->inputs.size())));
call_lowered_attrs->metadata.Set(
"prim_shape_fn_num_outputs",
Integer(static_cast<int>(lowered_shape_func->outputs.size())));
Array<GlobalVar> all_prim_shape_fn_vars;
for (const auto& pair : lowered_shape_func->funcs->functions) {
CHECK(pair.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
all_prim_shape_fn_vars.push_back(pair.first);
}
return CallLowered(lowered_func->prim_fn_var, visited_args, Attrs(call_lowered_attrs),
type_args, span);
call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
}

return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs),
type_args, std::move(span));
}

std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
@@ -655,7 +641,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
BaseFunc prim_func = ResolveToPrimitive(new_value);

if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
// Remember let var is bound (possibly indirectly) to a non-tir primitive.
// Remember let var is bound (possibly indirectly) to a primitive.
Function func = Downcast<Function>(prim_func);
primitive_functions_.emplace(var, func);
}
@@ -672,8 +658,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
}

Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
// Nothing to lower inside primitive functions.
if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
function_node->GetAttr<String>(attr::kExternalSymbol)) {
// Nothing to lower inside primitive/external functions.
return GetRef<Function>(function_node);
} else {
return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
@@ -682,8 +669,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
// Passes before lowering might insert a call_lowered to call a function that has already
// been lowered. Therefore we might see call_lowered ops here, but we don't need to do anything
// because ResolveToPrimitive returns null for all calls where the call_node->op is an OpNode
// been lowered. Therefore we might see call_lowered ops here, but we don't need to do
// anything because ResolveToPrimitive returns null for all calls where the call_node->op is
// an OpNode
Call call = GetRef<Call>(call_node);

// Look for (indirect) calls to primitives.
@@ -1056,8 +1044,6 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr
// retrieved. They may, however, have been renamed.
compiler->AddExterns(updated_module);

VLOG(1) << "rewritten but without lowered functions:" << std::endl << PrettyPrint(updated_module);

// Add the lowered functions.
IRModule lowered_module = compiler->GetLoweredFunctions();
for (const auto& kv : lowered_module->functions) {
@@ -1066,7 +1052,9 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr

VLOG(1) << "rewritten with lowered functions:" << std::endl << PrettyPrint(updated_module);

// Invoke external codegen for all Functions in the cache tagged with "Compiler".
// Invoke external codegen for all Functions in the cache tagged with "Compiler", and
// annotate the module with the resulting runtime modules.
// TODO(mbs): runtime modules should be first class rather than attributes.
Array<runtime::Module> external_modules = compiler->LowerExternalFunctions();

// Annotate the module with C Device API context mapping (this is until we have Target's
@@ -1080,6 +1068,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr

if (backend::IsAutoSchedulerEnabled()) {
// Capture all the operator weights.
// TODO(mbs): Constants should be represented as global defns in the IRModule.
updated_module = WithAttr(updated_module, "op_weights", compiler->GetOpWeights());
}

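To exercise the dynamic-shape branch above (the IsDynamic(func->ret_type) path, where a companion shape function is lowered for the host CPU), here is a hedged sketch, again assuming a build that includes this patch; the module is illustrative only:

import numpy as np
import tvm
from tvm import relay

# An input with a statically-unknown dimension makes the return type dynamic,
# so lowering must also produce a shape function and record the
# prim_shape_fn_* metadata on the call_lowered call.
x = relay.var("x", shape=(relay.Any(),), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

with tvm.transform.PassContext(opt_level=3):
    exe = relay.vm.compile(mod, target="llvm")

print(exe.primitive)  # should list the lowered kernel and its shape function

vm = tvm.runtime.vm.VirtualMachine(exe, tvm.cpu())
out = vm.invoke("main", tvm.nd.array(np.arange(4, dtype="float32")))
print(out.numpy())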
5 changes: 4 additions & 1 deletion src/relay/backend/vm/compiler.cc
@@ -934,7 +934,10 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)

VLOG(1) << std::endl
<< "-------------------------------------------------" << std::endl
<< exec_->GetVirtualDevices() << exec_->GetConstants() << exec_->GetBytecode()
<< exec_->GetVirtualDevices() //
<< exec_->GetConstants() //
<< exec_->GetPrimitives() //
<< exec_->GetBytecode() //
<< "-------------------------------------------------";

if (backend::IsAutoSchedulerEnabled()) {
5 changes: 5 additions & 0 deletions src/relay/transforms/dead_code.cc
@@ -79,6 +79,11 @@ class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
// allowed.
for (const auto& kv : mod_->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
function_node->GetAttr<String>(attr::kExternalSymbol)) {
// Ignore primitive and external functions.
continue;
}
// Everything of interest will be recorded in the purity maps so we ignore the result.
(void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
}