Skip to content

Commit

Permalink
[checkpoint] Fix GetType recursion, get debug going.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Jul 30, 2021
1 parent f95b22c commit 9b894d7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
6 changes: 5 additions & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ void handler(int sig) {
*
* If not we can evaluate it directly and don't need to bind it into a fresh module.
*/
class NeedsPreparationVisitor : private ExprVisitor {
class NeedsPreparationVisitor : public ExprVisitor {
public:
bool needs_preparation = false;

Expand Down Expand Up @@ -964,7 +964,11 @@ ObjectRef Interpret(IRModule mod, Expr expr, Device device, Target target) {
signal(SIGBUS, handler);

DLOG(INFO) << "interpreting:\n" << expr << "\nw.r.t. module:\n" << mod;

// If expr is simple enough we can avoid binding it into the module.
NeedsPreparationVisitor visitor;
visitor.VisitExpr(expr);

Expr expr_to_eval;
IRModule mod_and_expr; // default empty
if (visitor.needs_preparation) {
Expand Down
33 changes: 15 additions & 18 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ class LowerTensorExprMutator : public ExprMutator {
device_context_map_(device_ctx_map),
process_fn_(process_fn),
module_name_(module_name),
compiler_(compiler) {}
compiler_(compiler),
debug_op_(Op::Get("debug")) {}

/*!
* \brief Returns the primitive function associated with \p expr, or
Expand All @@ -419,18 +420,7 @@ class LowerTensorExprMutator : public ExprMutator {
Function ResolveToPrimitive(Expr expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(gvn));
if (const FunctionNode* fn = base_func.as<FunctionNode>()) {
if (!fn->HasNonzeroAttr(attr::kPrimitive)) {
DLOG(INFO) << "ignoring, global var is not bound to a primitive function";
return Function();
} else {
DLOG(INFO) << "global var bound to prim function";
return GetRef<Function>(fn);
}
} else {
DLOG(INFO) << "ignoring, global var is not bound to a Relay function";
return Function();
}
return ResolveToPrimitive(base_func);
} else if (const VarNode* vn = expr.as<VarNode>()) {
auto itr = primitive_functions_.find(GetRef<Var>(vn));
if (itr == primitive_functions_.end()) {
Expand All @@ -443,20 +433,25 @@ class LowerTensorExprMutator : public ExprMutator {
if (!fn->HasNonzeroAttr(attr::kPrimitive)) {
DLOG(INFO) << "ignoring, function is not primitive";
return Function();
} else {
DLOG(INFO) << "prim function";
return GetRef<Function>(fn);
}
if (const CallNode* cn = fn->body.as<CallNode>()) {
if (cn->op == debug_op_) {
DLOG(INFO) << "ignoring, primitive function is debug function";
return Function();
}
}
DLOG(INFO) << "prim function";
return GetRef<Function>(fn);
}
DLOG(INFO) << "ignoring other expression";
return Function();
}

/*!
* \brief Lowers the primitive function \p func to TIR for ultimate execution
* on a target with \p device_type. Returns the global var bound to the TIR
* on a target of \p device_type. Returns the global var bound to the TIR
* implementation and attributes to attach to the call to identify it as
* a primitive call.
* a TIR call.
*/
std::pair<GlobalVar, Attrs> LowerFunction(Function func, DLDeviceType device_type) {
Target target;
Expand Down Expand Up @@ -627,6 +622,8 @@ class LowerTensorExprMutator : public ExprMutator {
primitive_functions_;
String module_name_;
TECompiler compiler_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
};

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
Expand Down
21 changes: 20 additions & 1 deletion src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return ret;
}

// As above, but for a func which may be mutually recursive with itself. We use the
// functions type signature and find a fixed point in one step.
Type GetLetrecType(const BaseFunc& func) {
auto it = type_map_.find(func);
if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type;
}
if (func->checked_type_.defined()) {
ResolvedTypeInfo& rti = type_map_[func];
rti.checked_type = func->checked_type_;
}
Type ret = this->VisitExpr(func);
ICHECK(ret.defined());
KindCheck(ret, mod_, this->diag_ctx);
ResolvedTypeInfo& rti = type_map_[func];
rti.checked_type = ret;
return ret;
}

void EmitFatal(const Diagnostic& diag) { this->diag_ctx.EmitFatal(diag); }

// Visitor Logic
Expand All @@ -210,7 +229,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
if (mod_->ContainGlobalVar(var->name_hint)) {
// TODO(mbs): Is there a deep reason we were looking up the original types?
// (Mutual recursion should be ok give caching.)
return GetType(mod_->Lookup(var));
return GetLetrecType(mod_->Lookup(var));
} else {
DLOG(INFO) << "var not bound in module!";
return op->checked_type_;
Expand Down

0 comments on commit 9b894d7

Please sign in to comment.