diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 91d5b77eebf9..78dda45b954a 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -262,8 +262,8 @@ TVM_DLL Pass InferType(); * \brief Infer the type of an expression, reusing existing type information. * * The result of type checking is a new expression with unambiguous - * type information filled in for that expression only. The local - * version depends on existing type information populated throughout + * type information filled in for the given node only. The local + * version can use existing type information populated throughout * the expression and assumes this information is correct. The local * version also avoids examining large amounts of the graph assuming * type information is filled in properly which makes it much faster if we diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index a3f68c75cddf..5c3cc4c16f6c 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -830,7 +830,8 @@ class SameTypedSubgraphExtractor : public ExprMutator { by depending on existing type information being populated in expressions the target node depends on. If a node with populated type information is found we simply replace it with a variable of that type. In this way, we can avoid copying and - recursing through most of the expression graph. + recursing through most of the expression graph. Note, this assumes that current + populated type information is correct! ExprMutator is sufficient over MixedModemutator since we will not recurse much. */ @@ -899,6 +900,17 @@ class SameTypedSubgraphExtractor : public ExprMutator { namespace transform { Type InferTypeLocal(const Expr& expr) { + /* + This type inference differs from InferType in that it uses existing type information + to avoid recursing over much of the graph, and it only examines the type of the input + node. This makes it faster if you need to run type inference iteratively throughout + a pass for example. + + However, it assumes any existing populated type inference is correct! If some populated + type inference is incorrect, an incorrect type may be returned or a type error will be + raised. If you know not all populated type fields are correct with the current graph, + you should use InferType() instead. + */ SameTypedSubgraphExtractor subgraph_extractor; auto mod = IRModule::FromExpr(subgraph_extractor(expr));