-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMP][Pass][Typing] Add faster type inference #9735
Changes from 11 commits
f6747be
5d3932f
3220c80
7dee27b
08b391a
2021f23
5136b85
dbf3cf6
e9a5f55
f8c5012
5960c5c
f294f63
4f0b03b
8301057
1cb38f1
5aae167
d6f73f2
09fbbe0
faeed08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -176,13 +176,13 @@ class MixedPrecisionPass : public MixedModeMutator { | |
} | ||
|
||
Type GetType(const Expr& expr) const { | ||
auto mod = IRModule::FromExpr(expr); | ||
mod = transform::InferType()(mod); | ||
if (expr.as<FunctionNode>()) { | ||
return mod->Lookup("main")->checked_type(); | ||
} else { | ||
return mod->Lookup("main").as<FunctionNode>()->body->checked_type(); | ||
Type checked_type = expr->checked_type_; | ||
if (checked_type.defined()) { | ||
return checked_type; | ||
} | ||
|
||
// This also populates the checked_type_ field for expr | ||
return transform::InferTypeLocal(expr); | ||
} | ||
|
||
bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const { | ||
|
@@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator { | |
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span); | ||
} | ||
|
||
Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { | ||
// The old checked type in the expression may not be valid so clear it | ||
post->checked_type_ = Type(nullptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. am I missing something or will checked_type_ = null iff some sub-expression of post has been rewritten and thus it's type has changed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm so you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference? giving us this problem). I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (it will be fp32 instead of fp16) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248 Here is the behavior for generating post, there is some Copy on write stuff which i don't quite understand the full mechanics of so 🤷 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully. |
||
return post; | ||
} | ||
|
||
Expr Rewrite_(const TupleNode* pre, const Expr& post) { | ||
// The old checked type in the expression may not be valid so clear it | ||
post->checked_type_ = Type(nullptr); | ||
return post; | ||
} | ||
|
||
Expr VisitExpr_(const FunctionNode* func) final { | ||
// Erase the ret_type annotation and let the normal pass recalculate | ||
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) { | |
} | ||
} | ||
|
||
class SameTypedSubgraphExtractor : public ExprMutator { | ||
/* | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. micro nit: move to before class, used /*! etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Returns the largest sub-graph who's inner nodes need types and leaves are vars standing in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
Creates a small subgraph with the same type as the input expression. We attempt to do | ||
by depending on existing type information being populated in expressions the target | ||
node depends on. If a node with populated type information is found we simply | ||
replace it with a variable of that type. In this way, we can avoid copying and | ||
recursing through most of the expression graph. Note, this assumes that current | ||
populated type information is correct! | ||
|
||
ExprMutator is sufficient over MixedModemutator since we will not recurse much. | ||
*/ | ||
|
||
Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); } | ||
Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); } | ||
Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); } | ||
Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); } | ||
Expr VisitExpr_(const TupleNode* op) { | ||
return Tuple(get_analogous_expression(op->fields), op->span); | ||
} | ||
Expr VisitExpr_(const FunctionNode* op) { | ||
// Here will be the only VisitExpr | ||
return Function(op->params, get_analogous_expression(op->body), op->ret_type, op->type_params, | ||
op->attrs, op->span); | ||
} | ||
Expr VisitExpr_(const CallNode* op) { | ||
return Call(op->op, get_analogous_expression(op->args), op->attrs, op->type_args, op->span); | ||
} | ||
Expr VisitExpr_(const LetNode* op) { | ||
return Let(op->var, get_analogous_expression(op->value), get_analogous_expression(op->body), | ||
op->span); | ||
} | ||
Expr VisitExpr_(const IfNode* op) { | ||
return If(get_analogous_expression(op->cond), get_analogous_expression(op->true_branch), | ||
get_analogous_expression(op->false_branch), op->span); | ||
} | ||
Expr VisitExpr_(const TupleGetItemNode* op) { | ||
return TupleGetItem(get_analogous_expression(op->tuple), op->index, op->span); | ||
} | ||
Expr VisitExpr_(const RefCreateNode* op) { | ||
return RefCreate(get_analogous_expression(op->value), op->span); | ||
} | ||
Expr VisitExpr_(const RefReadNode* op) { | ||
return RefRead(get_analogous_expression(op->ref), op->span); | ||
} | ||
Expr VisitExpr_(const RefWriteNode* op) { | ||
return RefWrite(get_analogous_expression(op->ref), get_analogous_expression(op->value), | ||
op->span); | ||
} | ||
Expr VisitExpr_(const ConstructorNode* op) { | ||
return Constructor(op->name_hint, op->inputs, op->belong_to); | ||
} | ||
Expr VisitExpr_(const MatchNode* op) { | ||
return Match(get_analogous_expression(op->data), op->clauses, op->complete, op->span); | ||
} | ||
|
||
private: | ||
Expr get_analogous_expression(const Expr& expr) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: GetAnalogousExpression There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
// Replace the expression with a potentially simpler expression of the same type | ||
if (!expr->checked_type_.defined()) { | ||
return VisitExpr(expr); | ||
} | ||
|
||
return Var("dummy_var", expr->checked_type(), expr->span); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. // Since the expression already has a checked_type which we trust we don't need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
Array<Expr> get_analogous_expression(const Array<Expr>& fields) { | ||
Array<Expr> new_fields; | ||
for (Expr expr : fields) { | ||
new_fields.push_back(get_analogous_expression(expr)); | ||
} | ||
return new_fields; | ||
} | ||
}; | ||
|
||
namespace transform { | ||
|
||
Type InferTypeLocal(const Expr& expr) { | ||
/* | ||
This type inference differs from InferType in that it uses existing type information | ||
to avoid recursing over much of the graph, and it only examines the type of the input | ||
node. This makes it faster if you need to run type inference iteratively throughout | ||
a pass for example. | ||
|
||
However, it assumes any existing populated type inference is correct! If some populated | ||
type inference is incorrect, an incorrect type may be returned or a type error will be | ||
raised. If you know not all populated type fields are correct with the current graph, | ||
you should use InferType() instead. | ||
*/ | ||
SameTypedSubgraphExtractor subgraph_extractor; | ||
auto mod = IRModule::FromExpr(subgraph_extractor(expr)); | ||
|
||
mod = transform::InferType()(mod); | ||
Type result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type(); | ||
|
||
expr->checked_type_ = result_type; | ||
return result_type; | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) { | ||
return InferTypeLocal(expr); | ||
}); | ||
|
||
Pass InferType() { | ||
auto pass_info = PassInfo(0, "InferType", {}); | ||
return tvm::transform::CreateModulePass( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// The expression has not been changed AND it's existing type
// is known to still be valid. (See special handling for tuples etc
// below for where we null out checked_type_ when we can not
// sure it is still valid.
(though see my comment below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done