From 9b894d765c961b4ad0198a59808a373cf3b0d051 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 29 Jul 2021 17:10:24 -0700 Subject: [PATCH] [checkpoint] Fix GetType recursion, get debug going. --- src/relay/backend/interpreter.cc | 6 +++++- src/relay/backend/te_compiler.cc | 33 ++++++++++++++---------------- src/relay/transforms/type_infer.cc | 21 ++++++++++++++++++- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 4bf2c4bd49185..5c6d96a9a6738 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -936,7 +936,7 @@ void handler(int sig) { * * If not we can evaluate it directly and don't need to bind it into a fresh module. */ -class NeedsPreparationVisitor : private ExprVisitor { +class NeedsPreparationVisitor : public ExprVisitor { public: bool needs_preparation = false; @@ -964,7 +964,11 @@ ObjectRef Interpret(IRModule mod, Expr expr, Device device, Target target) { signal(SIGBUS, handler); DLOG(INFO) << "interpreting:\n" << expr << "\nw.r.t. module:\n" << mod; + + // If expr is simple enough we can avoid binding it into the module. NeedsPreparationVisitor visitor; + visitor.VisitExpr(expr); + Expr expr_to_eval; IRModule mod_and_expr; // default empty if (visitor.needs_preparation) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 328ebf35978b2..6ac3e7d926e21 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -410,7 +410,8 @@ class LowerTensorExprMutator : public ExprMutator { device_context_map_(device_ctx_map), process_fn_(process_fn), module_name_(module_name), - compiler_(compiler) {} + compiler_(compiler), + debug_op_(Op::Get("debug")) {} /*! * \brief Returns the primitive function associated with \p expr, or @@ -419,18 +420,7 @@ class LowerTensorExprMutator : public ExprMutator { Function ResolveToPrimitive(Expr expr) { if (const GlobalVarNode* gvn = expr.as()) { BaseFunc base_func = module_->Lookup(GetRef(gvn)); - if (const FunctionNode* fn = base_func.as()) { - if (!fn->HasNonzeroAttr(attr::kPrimitive)) { - DLOG(INFO) << "ignoring, global var is not bound to a primitive function"; - return Function(); - } else { - DLOG(INFO) << "global var bound to prim function"; - return GetRef(fn); - } - } else { - DLOG(INFO) << "ignoring, global var is not bound to a Relay function"; - return Function(); - } + return ResolveToPrimitive(base_func); } else if (const VarNode* vn = expr.as()) { auto itr = primitive_functions_.find(GetRef(vn)); if (itr == primitive_functions_.end()) { @@ -443,10 +433,15 @@ class LowerTensorExprMutator : public ExprMutator { if (!fn->HasNonzeroAttr(attr::kPrimitive)) { DLOG(INFO) << "ignoring, function is not primitive"; return Function(); - } else { - DLOG(INFO) << "prim function"; - return GetRef(fn); } + if (const CallNode* cn = fn->body.as()) { + if (cn->op == debug_op_) { + DLOG(INFO) << "ignoring, primitive function is debug function"; + return Function(); + } + } + DLOG(INFO) << "prim function"; + return GetRef(fn); } DLOG(INFO) << "ignoring other expression"; return Function(); @@ -454,9 +449,9 @@ class LowerTensorExprMutator : public ExprMutator { /*! * \brief Lowers the primitive function \p func to TIR for ultimate execution - * on a target with \p device_type. Returns the global var bound to the TIR + * on a target of \p device_type. Returns the global var bound to the TIR * implementation and attributes to attach to the call to identify it as - * a primitive call. + * a TIR call. */ std::pair LowerFunction(Function func, DLDeviceType device_type) { Target target; @@ -627,6 +622,8 @@ class LowerTensorExprMutator : public ExprMutator { primitive_functions_; String module_name_; TECompiler compiler_; + // Cache ops that need to be frequently used later to reduce lookup overhead. + const Op& debug_op_; }; Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 9e6388a4cc437..0e150111f677d 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -188,6 +188,25 @@ class TypeInferencer : private ExprFunctor, return ret; } + // As above, but for a func which may be mutually recursive with itself. We use the + // functions type signature and find a fixed point in one step. + Type GetLetrecType(const BaseFunc& func) { + auto it = type_map_.find(func); + if (it != type_map_.end() && it->second.checked_type.defined()) { + return it->second.checked_type; + } + if (func->checked_type_.defined()) { + ResolvedTypeInfo& rti = type_map_[func]; + rti.checked_type = func->checked_type_; + } + Type ret = this->VisitExpr(func); + ICHECK(ret.defined()); + KindCheck(ret, mod_, this->diag_ctx); + ResolvedTypeInfo& rti = type_map_[func]; + rti.checked_type = ret; + return ret; + } + void EmitFatal(const Diagnostic& diag) { this->diag_ctx.EmitFatal(diag); } // Visitor Logic @@ -210,7 +229,7 @@ class TypeInferencer : private ExprFunctor, if (mod_->ContainGlobalVar(var->name_hint)) { // TODO(mbs): Is there a deep reason we were looking up the original types? // (Mutual recursion should be ok give caching.) - return GetType(mod_->Lookup(var)); + return GetLetrecType(mod_->Lookup(var)); } else { DLOG(INFO) << "var not bound in module!"; return op->checked_type_;