[AMP][Pass][Typing] Add faster type inference #9735
Conversation
Discussed with @jroesch and @mbs-octoml; the main changes we want are to rename "Fast" --> "Local" and to better document pre-conditions.
Force-pushed from ac1ce9f to 5960c5c
This is now ready for review
Never looked at to_mixed_precision.cc before but boy do I see why this would help!
Just some nits, thanks, pretty sure this is going to get more use.
return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
Type checked_type = expr->checked_type_;
if (checked_type.defined()) {
  return checked_type;
// The expression has not been changed AND its existing type
// is known to still be valid. (See the special handling for tuples etc.
// below, where we null out checked_type_ when we cannot be
// sure it is still valid.)
(though see my comment below)
Done
@@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator {
    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
  }

  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
    // The old checked type in the expression may not be valid so clear it
    post->checked_type_ = Type(nullptr);
Am I missing something, or will checked_type_ be null iff some sub-expression of post has been rewritten and thus its type has changed?
I.e., checked_type_ is non-null only if pre == post.get()?
Hmm, you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference, giving us this problem).
I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (fp32 instead of fp16).
https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248
Here is the behavior for generating post; there is some copy-on-write stuff whose full mechanics I don't quite understand, so 🤷
Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW, but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully.
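To illustrate the hazard being discussed, here is a minimal toy sketch in Python (not the TVM API; all names are invented): a node caches its inferred type, and a mutator that rewrites children while reusing the parent object leaves that cache stale unless it is explicitly cleared, which is exactly why the `post->checked_type_ = Type(nullptr)` line above is needed.

```python
# Toy model: a node caches the result of type inference in checked_type_.
class Node:
    def __init__(self, dtype, children=()):
        self.dtype = dtype              # the node's "true" dtype
        self.children = list(children)
        self.checked_type_ = None       # cached result of inference

def infer_type(node):
    """Populate checked_type_ caches bottom-up (type == dtype here)."""
    for c in node.children:
        infer_type(c)
    node.checked_type_ = node.dtype
    return node.checked_type_

def cast_leaves_to_fp16(node):
    """Mutate in place, like a COW mutator that reuses the same object."""
    if not node.children:
        node.dtype = "float16"
    for c in node.children:
        cast_leaves_to_fp16(c)
    # Without this line the cache would still claim "float32":
    node.checked_type_ = None

leaf = Node("float32")
root = Node("float32", [leaf])
infer_type(root)                 # caches "float32" everywhere
cast_leaves_to_fp16(root)        # rewrite + invalidate the stale cache
assert leaf.dtype == "float16" and leaf.checked_type_ is None
```

If the cache-clearing line is deleted, a later reader of `checked_type_` sees fp32 for a node that is now fp16, matching the bug described above.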
src/relay/transforms/type_infer.cc (outdated)
@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
  }
}

class SameTypedSubgraphExtractor : public ExprMutator {
  /*
micro nit: move this to before the class, and use /*! etc.
nit: Returns the largest sub-graph whose inner nodes need types and whose leaves are vars standing in for already-typed sub-expressions.
Done
src/relay/transforms/type_infer.cc (outdated)
  }

 private:
  Expr get_analogous_expression(const Expr& expr) {
nit: GetAnalogousExpression
Done
src/relay/transforms/type_infer.cc (outdated)
    return VisitExpr(expr);
  }

  return Var("dummy_var", expr->checked_type(), expr->span);
// Since the expression already has a checked_type which we trust, we don't need
// full type inference to enter it. So stub it out with a dummy var of the same type.
Done
Was trying to play around with replacing some type inference in
Added (or rather replaced) some tests. PTAL @mbs-octoml
@AndrewZhaoLuo Sorry for the late reply. Does this help us solve the ADT problem in our MixedPrecision? Imagine we have one fn main():

let %1 = xxx;
let %2 = if (%1) {
  let %3: = @func___inference_a(%4, %5, %6)
} else {
  let %7: = @func___inference_b(%8, %9)
};

Then we have two subgraphs.
@FrozenGene I'm not sure I understand the concern 😅. Global var nodes are just used to reference function calls, right? And these functions have a known type ahead of time, right?
Squashed commits:
* reuse checked types
* analogous subgraph
* brr go fast
* clean up src logs
* clean up PR more
* more clean up
* more documenetation
* clean up
* formatting
* rename fast --> local
* more ocmments
* jostle ci
* type inference
* change comment for SameTypedSubgraphExtractor
* get_analogous_expression -> GetAnalogousExpression
* comment in GetAnaalogousExpression
* add comment
* replace infer tests
* jostle
@AndrewZhaoLuo Yes. In fact, when I saw your PR supports global var nodes, I thought you would leverage it to solve this TODO: https://github.com/apache/tvm/blob/main/src/relay/transforms/to_mixed_precision.cc#L297
@FrozenGene Ah yes, the type inference will work, but we need to think about how to handle it properly for AMP. When I initially wrote AMP I ignored constructs not usually found in most real-life models. It is on the list of TODOs here: #8296
This PR adds a faster type inference pass designed specifically for the Automatic Mixed Precision (AMP) pass. The issue is that the AMP pass uses the existing type inference infrastructure extensively, but that infrastructure was not designed for the AMP workload.

AMP works by going through the expression graph topologically, replacing nodes with casted versions, and relying on type inference to do so. However, in order to use type inference we must, for every subgraph, build an IRModule and run inference on it. The current type inference ignores previously populated type information and essentially repopulates the type fields of the entire subgraph under examination. With N nodes arranged in a linear fashion, the AMP pass examines N subgraphs, and for the i-th subgraph there are i nodes that IRModule construction and type inference must touch. This means we have at least O(N^2) runtime, which is bad. The key issues are therefore: 1) type inference discards the type information we already have, and 2) every query requires building and re-checking a full IRModule.
The solution I came up with is a bit of a hack that lets me avoid rewriting the type inference pass (which is essential infrastructure and would take a long time to change). Given an expression graph with partially populated type information, we can very easily construct, for any subgraph, an analogous graph which has the same type: we just replace each node with known type information with a constant or variable expression of that type. This means that if we are only interested in the type of a single node, we can extract a much smaller subgraph containing all the information needed to infer that type. We then build an IRModule and run standard type inference on this much smaller subgraph.
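The extraction trick can be sketched with a toy Python model (hypothetical names, not TVM's ExprMutator API): sub-expressions that already carry a checked type are swapped for a typed placeholder var, so only the untyped frontier survives into the graph handed to type inference.

```python
# Toy sketch of the "analogous subgraph" idea. A node with a known
# checked_type is replaced by a placeholder var of the same type, so the
# extracted graph is tiny no matter how deep the typed inputs are.
class Expr:
    def __init__(self, op, args=(), checked_type=None):
        self.op = op
        self.args = list(args)
        self.checked_type = checked_type  # None => type still unknown

def extract_analogous(expr):
    if expr.checked_type is not None:
        # Already typed: stand in a dummy var with the same type.
        return Expr("var", checked_type=expr.checked_type)
    return Expr(expr.op, [extract_analogous(a) for a in expr.args])

def count_nodes(e):
    return 1 + sum(count_nodes(a) for a in e.args)

# A fully-typed sub-expression feeding one new, untyped cast node:
deep = Expr("add", [Expr("x", checked_type="float32"),
                    Expr("y", checked_type="float32")],
            checked_type="float32")
new_cast = Expr("cast_fp16", [deep])     # type not yet inferred

small = extract_analogous(new_cast)
assert count_nodes(new_cast) == 4
assert count_nodes(small) == 2           # the cast + one typed placeholder
```

In the real pass the small graph is wrapped in an IRModule and run through the unchanged type inference, so the inference code itself never needs to learn about partial type information.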
This yields a ~100x reduction in AMP pass runtime. For example, arcfaceresnet100 on a 2020 M1 MacBook Pro went from 20s to 0.2s.
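A back-of-envelope cost model (my own simplification, not code from the PR) shows why the speedup is so large on a linear chain of N nodes: full re-inference at step i touches all i nodes seen so far, while local inference touches only the node and its direct inputs.

```python
# Cost model for type-checking work on a linear chain of n nodes.
def full_inference_cost(n):
    # Re-infer the whole i-node prefix at each of the n steps: O(n^2).
    return sum(i for i in range(1, n + 1))

def local_inference_cost(n, fanin=1):
    # Each query touches the node plus `fanin` typed placeholders: O(n).
    return n * (1 + fanin)

assert full_inference_cost(1000) == 500500   # quadratic blow-up
assert local_inference_cost(1000) == 2000    # linear in n
```

At n = 1000 the model already predicts a ~250x gap, which is the same order of magnitude as the measured 20s to 0.2s improvement.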