Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMP][Pass][Typing] Add faster type inference #9735

Merged
merged 19 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,29 @@ TVM_DLL Pass DynamicToStatic();
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* The result of type checking is a new expression with unambiguous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();

/*!
* \brief Infer the type of an expression, reusing existing type information.
*
* The result of type checking is a new expression with unambiguous
* type information filled in for the given node only. The local
* version can use existing type information populated throughout
* the expression and assumes this information is correct. The local
* version also avoids examining large amounts of the graph assuming
* type information is filled in properly which makes it much faster if we
* iteratively call type inference.
*
* \return The pass.
*/
TVM_DLL Type InferTypeLocal(const Expr& expr);

/*!
* \brief Search and eliminate common subexpression. For example, if there are
* two expressions evaluated to an identical value, a single variable is created
Expand Down
24 changes: 18 additions & 6 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ class MixedPrecisionPass : public MixedModeMutator {
}

Type GetType(const Expr& expr) const {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
if (expr.as<FunctionNode>()) {
return mod->Lookup("main")->checked_type();
} else {
return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
Type checked_type = expr->checked_type_;
if (checked_type.defined()) {
return checked_type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// The expression has not been changed AND it's existing type
// is known to still be valid. (See special handling for tuples etc
// below for where we null out checked_type_ when we can not
// sure it is still valid.

(though see my comment below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

// This also populates the checked_type_ field for expr
return transform::InferTypeLocal(expr);
}

bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
Expand Down Expand Up @@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator {
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
}

Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am I missing something or will checked_type_ = null iff some sub-expression of post has been rewritten and thus it's type has changed?
ie checked_type_ is non-null only if pre == post.get() ??

Copy link
Contributor Author

@AndrewZhaoLuo AndrewZhaoLuo Dec 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm so you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference? giving us this problem).

I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (it will be fp32 instead of fp16)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248

Here is the behavior for generating post, there is some Copy on write stuff which i don't quite understand the full mechanics of so 🤷

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully.

return post;
}

Expr Rewrite_(const TupleNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
return post;
}

Expr VisitExpr_(const FunctionNode* func) final {
// Erase the ret_type annotation and let the normal pass recalculate
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
Expand Down
99 changes: 99 additions & 0 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
}
}

class SameTypedSubgraphExtractor : public ExprMutator {
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

micro nit: move to before class, used /*! etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Returns the largest sub-graph who's inner nodes need types and leaves are vars standing in
for already typed sub-expressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Creates a small subgraph with the same type as the input expression. We attempt to do
by depending on existing type information being populated in expressions the target
node depends on. If a node with populated type information is found we simply
replace it with a variable of that type. In this way, we can avoid copying and
recursing through most of the expression graph. Note, this assumes that current
populated type information is correct!

ExprMutator is sufficient over MixedModemutator since we will not recurse much.
*/

Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
Expr VisitExpr_(const TupleNode* op) {
return Tuple(get_analogous_expression(op->fields), op->span);
}
Expr VisitExpr_(const FunctionNode* op) {
// Here will be the only VisitExpr
return Function(op->params, get_analogous_expression(op->body), op->ret_type, op->type_params,
op->attrs, op->span);
}
Expr VisitExpr_(const CallNode* op) {
return Call(op->op, get_analogous_expression(op->args), op->attrs, op->type_args, op->span);
}
Expr VisitExpr_(const LetNode* op) {
return Let(op->var, get_analogous_expression(op->value), get_analogous_expression(op->body),
op->span);
}
Expr VisitExpr_(const IfNode* op) {
return If(get_analogous_expression(op->cond), get_analogous_expression(op->true_branch),
get_analogous_expression(op->false_branch), op->span);
}
Expr VisitExpr_(const TupleGetItemNode* op) {
return TupleGetItem(get_analogous_expression(op->tuple), op->index, op->span);
}
Expr VisitExpr_(const RefCreateNode* op) {
return RefCreate(get_analogous_expression(op->value), op->span);
}
Expr VisitExpr_(const RefReadNode* op) {
return RefRead(get_analogous_expression(op->ref), op->span);
}
Expr VisitExpr_(const RefWriteNode* op) {
return RefWrite(get_analogous_expression(op->ref), get_analogous_expression(op->value),
op->span);
}
Expr VisitExpr_(const ConstructorNode* op) {
return Constructor(op->name_hint, op->inputs, op->belong_to);
}
Expr VisitExpr_(const MatchNode* op) {
return Match(get_analogous_expression(op->data), op->clauses, op->complete, op->span);
}

private:
Expr get_analogous_expression(const Expr& expr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: GetAnalogousExpression

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Replace the expression with a potentially simpler expression of the same type
if (!expr->checked_type_.defined()) {
return VisitExpr(expr);
}

return Var("dummy_var", expr->checked_type(), expr->span);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Since the expression already has a checked_type which we trust we don't need
// full type inference to enter it. So stub it out with a dummy var of the same type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
Array<Expr> get_analogous_expression(const Array<Expr>& fields) {
Array<Expr> new_fields;
for (Expr expr : fields) {
new_fields.push_back(get_analogous_expression(expr));
}
return new_fields;
}
};

namespace transform {

Type InferTypeLocal(const Expr& expr) {
/*
This type inference differs from InferType in that it uses existing type information
to avoid recursing over much of the graph, and it only examines the type of the input
node. This makes it faster if you need to run type inference iteratively throughout
a pass for example.

However, it assumes any existing populated type inference is correct! If some populated
type inference is incorrect, an incorrect type may be returned or a type error will be
raised. If you know not all populated type fields are correct with the current graph,
you should use InferType() instead.
*/
SameTypedSubgraphExtractor subgraph_extractor;
auto mod = IRModule::FromExpr(subgraph_extractor(expr));

mod = transform::InferType()(mod);
Type result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();

expr->checked_type_ = result_type;
return result_type;
}

TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) {
return InferTypeLocal(expr);
});

Pass InferType() {
auto pass_info = PassInfo(0, "InferType", {});
return tvm::transform::CreateModulePass(
Expand Down