Skip to content

Commit

Permalink
[AMP][Pass][Typing] Add faster type inference (apache#9735)
Browse files Browse the repository at this point in the history
* reuse checked types

* analogous subgraph

* brr go fast

* clean up src logs

* clean up PR more

* more clean up

* more documenetation

* clean up

* formatting

* rename fast --> local

* more ocmments

* jostle ci

* type inference

* change comment for SameTypedSubgraphExtractor

* get_analogous_expression -> GetAnalogousExpression

* comment in GetAnaalogousExpression

* add comment

* replace infer tests

* jostle
  • Loading branch information
AndrewZhaoLuo authored and ylc committed Jan 7, 2022
1 parent 8d803f5 commit d0d4ed8
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 21 deletions.
17 changes: 16 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,29 @@ TVM_DLL Pass DynamicToStatic();
/*!
* \brief Infer the type of an expression.
*
 * The result of type checking is a new expression with unambiguous
 * type information filled in, as well as its checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();

/*!
* \brief Infer the type of an expression, reusing existing type information.
*
* The result of type checking is a new expression with unambiguous
* type information filled in for the given node only. The local
* version can use existing type information populated throughout
* the expression and assumes this information is correct. The local
* version also avoids examining large amounts of the graph assuming
* type information is filled in properly which makes it much faster if we
* iteratively call type inference.
*
* \return The type of the expression.
*/
TVM_DLL Type InferTypeLocal(const Expr& expr);

/*!
* \brief Search and eliminate common subexpression. For example, if there are
* two expressions evaluated to an identical value, a single variable is created
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def InferType():
return _ffi_api.InferType()


def InferTypeLocal(expr):
    """Infer the type of a single expr, reusing type information to do so.

    This populates the checked_type field in expr. We assume existing type
    information in the graph is correct!

    Parameters
    ----------
    expr: relay.Expr
        The expression we want to know the type of

    Returns
    -------
    type: relay.Type
        The type of the expression
    """
    return _ffi_api.InferTypeLocal(expr)


def FoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense. This pass will
invoke both forward and backward scale folding.
Expand Down
28 changes: 22 additions & 6 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,17 @@ class MixedPrecisionPass : public MixedModeMutator {
}

Type GetType(const Expr& expr) const {
  // If the expression has not been changed AND its existing type
  // is known to still be valid, reuse it. (See the special handling for
  // tuples etc. below for where we null out checked_type_ when we cannot
  // be sure it is still valid.)
  Type checked_type = expr->checked_type_;
  if (checked_type.defined()) {
    return checked_type;
  }

  // Fall back to local type inference; this also populates the
  // checked_type_ field for expr as a side effect.
  return transform::InferTypeLocal(expr);
}

bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
Expand Down Expand Up @@ -381,6 +385,18 @@ class MixedPrecisionPass : public MixedModeMutator {
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
}

Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
  // The rewrite may have changed the tuple this projects from, so the old
  // checked type in the expression may no longer be valid — clear it so
  // GetType() re-infers instead of trusting a stale cached type.
  post->checked_type_ = Type(nullptr);
  return post;
}

Expr Rewrite_(const TupleNode* pre, const Expr& post) {
  // The rewrite may have changed this tuple's fields, so the old checked
  // type in the expression may no longer be valid — clear it so GetType()
  // re-infers instead of trusting a stale cached type.
  post->checked_type_ = Type(nullptr);
  return post;
}

Expr VisitExpr_(const FunctionNode* func) final {
// Erase the ret_type annotation and let the normal pass recalculate
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
Expand Down
106 changes: 106 additions & 0 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,8 +824,114 @@ void AddGlobalTypes(IRModule mod) {
}
}

/*!
 * \brief Returns a possibly much smaller subgraph whose inner nodes have the same type.
 *
 * Returns the largest sub-graph whose inner nodes need types and leaves are vars standing in
 * for already typed sub-expressions. This creates a graph whose inner nodes have the same
 * type as the original graph and when running type inference, we can avoid copying and
 * recursing through most of the expression graph when running type inference. Note, this assumes
 * that current populated type information is correct!
 *
 * ExprMutator is sufficient over MixedModeMutator since we will not recurse much.
 */
class SameTypedSubgraphExtractor : public ExprMutator {
  // Leaf nodes are copied directly; they carry their own type information.
  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
  // Composite nodes are rebuilt with already-typed children replaced by
  // dummy vars of the same type (see GetAnalogousExpression below).
  Expr VisitExpr_(const TupleNode* op) {
    return Tuple(GetAnalogousExpression(op->fields), op->span);
  }
  Expr VisitExpr_(const FunctionNode* op) {
    // Unfortunately our strategy of inserting variables as dummies would change the signature of
    // existing function nodes so we have to copy all used functions always :/
    return Function(op->params, op->body, op->ret_type, op->type_params, op->attrs, op->span);
  }
  Expr VisitExpr_(const CallNode* op) {
    return Call(op->op, GetAnalogousExpression(op->args), op->attrs, op->type_args, op->span);
  }
  Expr VisitExpr_(const LetNode* op) {
    return Let(op->var, GetAnalogousExpression(op->value), GetAnalogousExpression(op->body),
               op->span);
  }
  Expr VisitExpr_(const IfNode* op) {
    return If(GetAnalogousExpression(op->cond), GetAnalogousExpression(op->true_branch),
              GetAnalogousExpression(op->false_branch), op->span);
  }
  Expr VisitExpr_(const TupleGetItemNode* op) {
    return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, op->span);
  }
  Expr VisitExpr_(const RefCreateNode* op) {
    return RefCreate(GetAnalogousExpression(op->value), op->span);
  }
  Expr VisitExpr_(const RefReadNode* op) {
    return RefRead(GetAnalogousExpression(op->ref), op->span);
  }
  Expr VisitExpr_(const RefWriteNode* op) {
    return RefWrite(GetAnalogousExpression(op->ref), GetAnalogousExpression(op->value), op->span);
  }
  Expr VisitExpr_(const ConstructorNode* op) {
    return Constructor(op->name_hint, op->inputs, op->belong_to);
  }
  Expr VisitExpr_(const MatchNode* op) {
    return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, op->span);
  }

 private:
  Expr GetAnalogousExpression(const Expr& expr) {
    // Replace the expression with a potentially simpler expression of the same type
    if (expr->checked_type_.defined()) {
      // Since the expression already has a checked_type which we assume is correct we don't need
      // full type inference to enter it. So stub it out with a dummy var of the same type.
      return Var("dummy_var", expr->checked_type(), expr->span);
    }

    return VisitExpr(expr);
  }
  // Element-wise version of the above for argument/field lists.
  Array<Expr> GetAnalogousExpression(const Array<Expr>& fields) {
    Array<Expr> new_fields;
    for (Expr expr : fields) {
      new_fields.push_back(GetAnalogousExpression(expr));
    }
    return new_fields;
  }
};

namespace transform {

Type InferTypeLocal(const Expr& expr) {
  // Unlike InferType, this reuses any type information already populated in
  // the graph to avoid recursing over most of it, and it reports the type of
  // the input node only. That makes it much faster when type inference is
  // invoked iteratively during a pass.
  //
  // It assumes every populated checked_type_ field is correct! If some are
  // stale, an incorrect type may be returned or a type error raised; in that
  // situation use InferType() instead.
  SameTypedSubgraphExtractor extractor;
  Expr reduced_graph = extractor(expr);

  IRModule mod = IRModule::FromExpr(reduced_graph);
  mod = transform::InferType()(mod);

  const bool is_function = expr.as<FunctionNode>() != nullptr;
  Type result_type = is_function
                         ? mod->Lookup("main")->checked_type()
                         : mod->Lookup("main").as<FunctionNode>()->body->checked_type();

  // Cache the result on the input expression so later queries are free.
  expr->checked_type_ = result_type;
  return result_type;
}

// Expose local type inference to the FFI (python/tvm/relay/transform).
TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed(InferTypeLocal);

Pass InferType() {
auto pass_info = PassInfo(0, "InferType", {});
return tvm::transform::CreateModulePass(
Expand Down
27 changes: 13 additions & 14 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
"""
import pytest
import tvm

from tvm import IRModule, te, relay, parser
from tvm.relay import op, transform, analysis
from tvm import IRModule, parser, relay, te
from tvm.relay import analysis, op, transform
from tvm.relay.op import op as _op


Expand All @@ -33,12 +32,9 @@ def infer_mod(mod, annotate_spans=True):
return mod


def infer_expr(expr, annotate_spans=True):
mod = IRModule.from_expr(expr)
mod = infer_mod(mod, annotate_spans)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def infer_expr(expr):
    # Infer types in place via InferTypeLocal (which populates the
    # checked_type field of expr) and return the same expression.
    transform.InferTypeLocal(expr)
    return expr


def assert_has_type(expr, typ, mod=None):
def test_monomorphic_let():
    """A let-bound float64 constant should be inferred as scalar float64."""
    # TODO(@jroesch): this seems whack.
    sb = relay.ScopeBuilder()
    x = relay.var("x", dtype="float64", shape=())
    # Bind the Var object itself (not the string "x") so the let uses the
    # annotated variable.
    x = sb.let(x, relay.const(1.0, "float64"))
    sb.ret(x)
    xchecked = infer_expr(sb.get())
    assert xchecked.checked_type == relay.scalar_type("float64")
Expand Down Expand Up @@ -165,11 +161,11 @@ def @f(%n: int32, %data: float32) -> float32 {
def test_incomplete_call():
    """Calling an unannotated function var should infer its function type."""
    tt = relay.scalar_type("int32")
    x = relay.var("x", tt)
    f = relay.var("f")
    func = relay.Function([x, f], relay.Call(f, [x]), tt)

    ft = infer_expr(func)
    # f is used as int32 -> int32, so inference should assign it this type.
    f_type = relay.FuncType([tt], tt)
    assert ft.checked_type == relay.FuncType([tt, f_type], tt)


Expand Down Expand Up @@ -245,7 +241,7 @@ def test_ref():
def test_free_expr():
    # Inference on an expression with a free variable must keep the original
    # variable (same vid) rather than substituting a fresh one.
    x = relay.var("x", "float32")
    y = relay.add(x, x)
    yy = infer_expr(y)
    assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
    assert yy.checked_type == relay.scalar_type("float32")
    assert x.vid.same_as(yy.args[0].vid)
Expand All @@ -255,8 +251,11 @@ def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
ty_z = infer_expr(z)
ty_args = ty_z.type_args

# InferTypeLocal does not support populating the type_args field
mod = infer_mod(IRModule.from_expr(z))
mod = infer_mod(mod, annotate_spans=False)
ty_args = mod["main"].body.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"
Expand Down

0 comments on commit d0d4ed8

Please sign in to comment.